From d9638712895a6ee53e2da1154d94f16b9318d740 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 15 Apr 2025 15:31:17 +0200 Subject: [PATCH 001/224] fg rule-based detection algo from FARM --- .gitignore | 1 + .../preprocessing/fg_detection/__init__.py | 0 .../preprocessing/fg_detection/rule_based.py | 1478 +++++++++++++++++ 3 files changed, 1479 insertions(+) create mode 100644 .gitignore create mode 100644 chebai_graph/preprocessing/fg_detection/__init__.py create mode 100644 chebai_graph/preprocessing/fg_detection/rule_based.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a09c56d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/.idea diff --git a/chebai_graph/preprocessing/fg_detection/__init__.py b/chebai_graph/preprocessing/fg_detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai_graph/preprocessing/fg_detection/rule_based.py b/chebai_graph/preprocessing/fg_detection/rule_based.py new file mode 100644 index 0000000..4c7236a --- /dev/null +++ b/chebai_graph/preprocessing/fg_detection/rule_based.py @@ -0,0 +1,1478 @@ +# This code file is taken from https://github.com/thaonguyen217/farm_molecular_representation +# Reference Paper: Nguyen, Thao, et al. "FARM: Functional Group-Aware Representations for Small Molecules." +# arXiv preprint arXiv:2410.02082 (2024). 
def sdf2smiles(sdf_file):
    """Collect the SMILES strings of every readable molecule in an SDF file.

    Args:
        sdf_file: Path handed to ``Chem.SDMolSupplier``.

    Returns:
        set: One ``Chem.MolToSmiles`` string per record the supplier could
        parse. Records the supplier yields as ``None`` are skipped, and
        identical SMILES collapse into a single entry.
    """
    supplier = Chem.SDMolSupplier(sdf_file)
    return {Chem.MolToSmiles(mol) for mol in supplier if mol is not None}


def ring_size_processing(ring_size):
    """Orient a sequence of ring sizes so its first entry is not larger than its last.

    Only the two end values are compared. When the first is larger, the whole
    sequence is reversed; the entries are never sorted.

    Args:
        ring_size: Sequence of ring sizes (for example ``[6, 5]``). May be
            empty.

    Returns:
        list: A new list in the chosen orientation. The input is never
        modified and never returned as the same object, so the caller can
        mutate the result freely.
    """
    # An empty sequence has no end values to compare; indexing it would raise
    # IndexError, so answer with an empty list instead.
    if not ring_size:
        return []
    if ring_size[0] > ring_size[-1]:
        return list(reversed(ring_size))
    # Copy in the unchanged case too, so both branches hand back a fresh list.
    return list(ring_size)
def find_connected_rings(ring, remaining_rings):
    """Gather the rings that share atoms, directly or transitively, with ``ring``.

    Args:
        ring: Set of atom indices of the seed ring.
        remaining_rings: List of atom-index sets not yet assigned to a group.
            It is modified in place: every ring absorbed into the group is
            removed from it.

    Returns:
        list: The seed ring followed by each absorbed ring, in the order in
        which the rings were picked up.
    """
    group = [ring]
    merged_atoms = ring
    changed = True
    while changed:
        changed = False
        position = 0
        # The cursor always moves forward, even right after a removal. The
        # ring that slides into the vacated slot is therefore examined on the
        # next sweep, not this one. Sweeps repeat until one absorbs nothing,
        # so no connected ring is missed; only the pick-up order depends on
        # this scan pattern.
        while position < len(remaining_rings):
            candidate = remaining_rings[position]
            position += 1
            if not (merged_atoms & candidate):
                continue
            # Any shared atom puts the candidate into the same group.
            group.append(candidate)
            remaining_rings.remove(candidate)
            merged_atoms = merged_atoms | candidate
            changed = True
    return group
block in enumerate(fused_ring_blocks): + rs = '-'.join(str(size) for size in ring_size_processing(ring_sizes[i])) + for idx in block: + atom = mol.GetAtomWithIdx(idx) + atom.SetProp('RING', rs) + + ######## SET FUNCTIONAL GROUP PROP ######## + for atom in mol.GetAtoms(): + atom_symbol = atom.GetSymbol() + atom_neighbors = atom.GetNeighbors() + atom_num_neighbors = len(atom_neighbors) + num_H = atom.GetTotalNumHs() + in_ring = atom.IsInRing() + atom_idx = atom.GetIdx() + charge = atom.GetFormalCharge() + + ########################### Groups containing oxygen ########################### + if atom_symbol in ['C', '*'] and charge == 0: # and atom.GetProp('FG') == '': + num_O, num_X, num_C, num_N, num_S = 0, 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['F', 'Cl', 'Br', 'I']: + num_X += 1 + if neighbor.GetSymbol() == 'O': + num_O += 1 + if neighbor.GetSymbol() in ['C', '*']: + num_C += 1 + if neighbor.GetSymbol() == 'N': + num_N += 1 + if neighbor.GetSymbol() == 'S': + num_S += 1 + + if num_H == 1 and atom_num_neighbors == 3 and charge == 0 and atom.GetProp('FG') == '': + atom.SetProp('FG', 'tertiary_carbon') + if atom_num_neighbors == 4 and charge == 0 and atom.GetProp('FG') == '': + atom.SetProp('FG', 'quaternary_carbon') + if num_H == 0 and atom_num_neighbors == 3 and charge == 0 and atom.GetProp('FG') == '' and not in_ring: + atom.SetProp('FG', 'alkene_carbon') + + if num_O == 1 and atom_symbol == 'C' and atom.GetProp('FG') not in ['hemiacetal', 'hemiketal', 'acetal', 'ketal', 'orthoester', 'orthocarbonate_ester', 'carbonate_ester']: + if num_N == 1: # Cyanate and Isocyanate + condition1, condition2 = False, False + condition3, condition4= False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.TRIPLE and neighbor.GetFormalCharge() == 0: + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, 
neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + condition2 = True + + if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition3 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + condition4 = True + + if condition1 and condition2 and not in_ring: # Cyanate + atom.SetProp('FG', 'cyanate') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'cyanate') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() in ['C', '*'] and C_neighbor.GetIdx() != atom_idx: + C_neighbor.SetProp('FG', '') + + if condition3 and condition4 and not in_ring: # Isocyanate + atom.SetProp('FG', 'isocyanate') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'isocyanate') + + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + bond = mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()) + bondtype = bond.GetBondType() + if bondtype == Chem.BondType.SINGLE: # and not neighbor.IsInRing(): # [C-O]: Alcohol (COH) or Ether [COC] or Hydroperoxy [C-O-O-H] or Peroxide [C-O-O-C] + if neighbor.GetTotalNumHs() == 1: # Alcohol [COH] + neighbor.SetProp('FG', 'hydroxyl') + else: + for O_neighbor in neighbor.GetNeighbors(): + # if not O_neighbor.IsInRing(): + if O_neighbor.GetIdx() != atom_idx and O_neighbor.GetSymbol() in ['C', '*'] and neighbor.GetProp('FG') == '': # Ether [COC] + neighbor.SetProp('FG', 'ether') + if O_neighbor.GetSymbol() == 'O': + if O_neighbor.GetTotalNumHs() == 1: # Hydroperoxy [C-O-O-H] + neighbor.SetProp('FG', 'hydroperoxy') + O_neighbor.SetProp('FG', 'hydroperoxy') + else: + neighbor.SetProp('FG', 'peroxy') + O_neighbor.SetProp('FG', 'peroxy') + + if bondtype == Chem.BondType.DOUBLE: # [C=O]: Ketone [CC(=0)C] or Aldehyde [CC(=O)H] or Acyl halide [C(=O)X] + 
if num_X == 1 and not neighbor.IsInRing(): # Acyl halide [C(=O)X] + atom.SetProp('FG', 'haloformyl') + for neighbor_ in atom_neighbors: + if neighbor_.GetSymbol() in ['O', 'F', 'Cl', 'Br', 'I']: + neighbor_.SetProp('FG', 'haloformyl') + + if (num_C == 1 and num_H == 1) or num_H == 2 and not in_ring: # Aldehyde [C(=O)H] + atom.SetProp('FG', 'aldehyde') + neighbor.SetProp('FG', 'aldehyde') + + if atom_num_neighbors == 3 and atom.GetProp('FG') not in ['haloformyl', 'amide']: # Ketone [C(=0)C] + atom.SetProp('FG', 'ketone') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and not neighbor.IsInRing(): + neighbor.SetProp('FG', 'ketone') + + if num_O == 2: # and atom.GetProp('FG') == '': + if atom_num_neighbors == 3: + if num_H == 0: + condition1, condition2, condition3, condition4 = False, False, False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == -1 and not neighbor.IsInRing(): + condition2 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + condition3 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0 and atom.GetProp('FG') != 'carbamate': + condition4 = True + + if condition1 and condition2: + atom.SetProp('FG', 'carboxylate') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'carboxylate') + 
if condition1 and condition3: + atom.SetProp('FG', 'carboxyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'carboxyl') + if condition1 and condition4 and atom.GetProp('FG') not in ['carbamate', 'carbonate_ester']: + atom.SetProp('FG', 'ester') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'ester') + for O_neighbor in neighbor.GetNeighbors(): + O_neighbor.SetProp('FG', 'ester') + + if num_H == 1 and not in_ring: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1: + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + condition2 = True + cnt += 1 + + if condition1 and condition2: + atom.SetProp('FG', 'hemiacetal') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'hemiacetal') + if cnt == 2: + atom.SetProp('FG', 'acetal') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'acetal') + + if atom_num_neighbors == 4 and not in_ring: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + 
condition2 = True + cnt += 1 + + if condition1 and condition2: + atom.SetProp('FG', 'hemiketal') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'hemiketal') + if cnt == 2: + atom.SetProp('FG', 'ketal') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'ketal') + + if num_O == 3 and atom_num_neighbors == 4 and not in_ring: + n_C = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + n_C += 1 + if n_C == 3: + atom.SetProp('FG', 'orthoester') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'orthoester') + + if num_O == 3 and atom_num_neighbors == 3 and charge == 0 and not in_ring: + condition1 = False + n_O = 0 + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition1 = True + if mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + n_O += 1 + if condition1 and n_O == 2: + atom.SetProp('FG', 'carbonate_ester') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'carbonate_ester') + + if num_O == 4 and not in_ring: + n_C = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + n_C += 1 + if n_C == 4: + atom.SetProp('FG', 'orthocarbonate_ester') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'orthocarbonate_ester') + + ########################### Groups containing nitrogen ########################### + #### 
Amidine #### + if num_N == 2 and atom_num_neighbors == 3: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + condition1 = True + if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + condition2 = True + if condition1 and condition2: + atom.SetProp('FG', 'amidine') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'N': + neighbor.SetProp('FG', 'amidine') + + if num_N == 1 and num_O == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 2: + condition2 = True + if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 3 and not neighbor.IsInRing(): + condition3 = True + if condition1 and condition2 and condition3: + atom.SetProp('FG', 'carbamate') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'carbamate') + + if num_N == 1 and num_S == 1: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == 
Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and not neighbor.IsInRing(): + condition1 = True + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + condition2 = True + if condition1 and condition2: + atom.SetProp('FG', 'isothiocyanate') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'isothiocyanate') + + if num_S == 1 and atom_num_neighbors == 3: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + atom.SetProp('FG', 'thioketone') + neighbor.SetProp('FG', 'thioketone') + + if num_S == 1 and num_H == 1 and atom_num_neighbors == 2: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + atom.SetProp('FG', 'thial') + neighbor.SetProp('FG', 'thial') + + if num_S == 1 and num_O == 1 and atom_num_neighbors == 3: + condition1, condition2 = False, False + condition3, condition4 = False, False + condition5, condition6 = False, False + condition7, condition8 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and not neighbor.IsInRing(): + condition2 = True + + if neighbor.GetSymbol() == 'O' and 
mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + condition3 = True + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetTotalNumHs() == 0 and not len(neighbor.GetNeighbors())==1: + condition4 = True + + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + flag = True + for bond in neighbor.GetBonds(): + if bond.GetBondType() != Chem.BondType.SINGLE: + flag = False + if flag: + condition5 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and not neighbor.IsInRing(): + condition6 = True + + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + condition7 = True + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetTotalNumHs() == 0 and len(neighbor.GetNeighbors())==1 and not neighbor.IsInRing(): + condition8 = True + + if condition1 and condition2: + atom.SetProp('FG', 'carbothioic_S-acid') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['S', 'O']: + neighbor.SetProp('FG', 'carbothioic_S-acid') + if condition3 and condition4: + atom.SetProp('FG', 'carbothioic_O-acid') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['S', 'O']: + neighbor.SetProp('FG', 'carbothioic_O-acid') + if condition5 and condition6: + atom.SetProp('FG', 'thiolester') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['S', 'O']: + 
neighbor.SetProp('FG', 'thiolester') + if condition7 and condition8: + atom.SetProp('FG', 'thionoester') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['S', 'O']: + neighbor.SetProp('FG', 'thionoester') + + + if num_S == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and len(neighbor.GetNeighbors()) == 1 and not neighbor.IsInRing(): + condition1 = True + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetTotalNumHs() == 0 and len(neighbor.GetNeighbors()) == 1 and not neighbor.IsInRing(): + condition2 = True + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 0 and len(neighbor.GetNeighbors()) == 2 and not neighbor.IsInRing(): + flag = True + for bond in neighbor.GetBonds(): + if bond.GetBondType() != Chem.BondType.SINGLE: + flag = False + if flag: + condition3 = True + + if condition1 and condition2: + atom.SetProp('FG', 'carbodithioic_acid') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S': + neighbor.SetProp('FG', 'carbodithioic_acid') + if condition3 and condition2: + atom.SetProp('FG', 'carbodithio') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S': + neighbor.SetProp('FG', 'carbodithio') + + if num_X == 3 and charge == 0 and atom_num_neighbors == 4: + num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'F': + num_F += 1 + if neighbor.GetSymbol() == 'Cl': + num_Cl += 1 + if neighbor.GetSymbol() == 'Br': + num_Br += 1 + if neighbor.GetSymbol() == 'I': + num_I += 1 + if num_F == 3: + atom.SetProp('FG', 'trifluoromethyl') + for 
neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'F': + neighbor.SetProp('FG', 'trifluoromethyl') + if num_F == 2 and num_Cl == 1: + atom.SetProp('FG', 'difluorochloromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['F', 'Cl']: + neighbor.SetProp('FG', 'difluorochloromethyl') + if num_F == 2 and num_Br == 1: + atom.SetProp('FG', 'bromodifluoromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['F', 'Br']: + neighbor.SetProp('FG', 'bromodifluoromethyl') + + if num_Cl == 3: + atom.SetProp('FG', 'trichloromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'Cl': + neighbor.SetProp('FG', 'trichloromethyl') + if num_Cl == 2 and num_Br == 1: + atom.SetProp('FG', 'bromodichloromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['Cl', 'Br']: + neighbor.SetProp('FG', 'bromodichloromethyl') + + if num_Br == 3: + atom.SetProp('FG', 'tribromomethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'Br': + neighbor.SetProp('FG', 'tribromomethyl') + if num_Br == 2 and num_F == 1: + atom.SetProp('FG', 'dibromofluoromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['F', 'Br']: + neighbor.SetProp('FG', 'dibromofluoromethyl') + + if num_I == 3: + atom.SetProp('FG', 'triiodomethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'I': + neighbor.SetProp('FG', 'triiodomethyl') + + if num_X == 2 and charge == 0 and atom_num_neighbors == 3 and num_H == 1: + num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'F': + num_F += 1 + if neighbor.GetSymbol() == 'Cl': + num_Cl += 1 + if neighbor.GetSymbol() == 'Br': + num_Br += 1 + if neighbor.GetSymbol() == 'I': + num_I += 1 + + if num_F == 2: + atom.SetProp('FG', 'difluoromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'F': + neighbor.SetProp('FG', 'difluoromethyl') + if num_F == 1 and num_Cl == 1: + atom.SetProp('FG', 
'fluorochloromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['F', 'Cl']: + neighbor.SetProp('FG', 'fluorochloromethyl') + + if num_Cl == 2: + atom.SetProp('FG', 'dichloromethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'Cl': + neighbor.SetProp('FG', 'dichloromethyl') + if num_Cl == 1 and num_Br == 1: + atom.SetProp('FG', 'chlorobromomethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['Cl', 'Br']: + neighbor.SetProp('FG', 'chlorobromomethyl') + if num_Cl == 1 and num_I == 1: + atom.SetProp('FG', 'chloroiodomethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['Cl', 'I']: + neighbor.SetProp('FG', 'chloroiodomethyl') + + if num_Br == 2: + atom.SetProp('FG', 'dibromomethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'Br': + neighbor.SetProp('FG', 'dibromomethyl') + if num_Br == 1 and num_I == 1: + atom.SetProp('FG', 'bromoiodomethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['Br', 'I']: + neighbor.SetProp('FG', 'bromoiodomethyl') + + if num_I == 2: + atom.SetProp('FG', 'diiodomethyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'I': + neighbor.SetProp('FG', 'diiodomethyl') + + if (atom_num_neighbors == 2 or atom_num_neighbors == 1) and not in_ring and atom.GetProp('FG') == '': + bonds = atom.GetBonds() + ns, nd, nt = 0, 0, 0 + for bond in bonds: + if bond.GetBondType() == Chem.BondType.SINGLE: + ns += 1 + elif bond.GetBondType() == Chem.BondType.DOUBLE: + nd += 1 + else: + nt += 1 + if ns >= 1 and nd == 0 and nt == 0: + atom.SetProp('FG', 'alkyl') + if nd >= 1: + atom.SetProp('FG', 'alkene') + if nt == 1: + atom.SetProp('FG', 'alkyne') + + elif atom_symbol == 'O' and not in_ring and charge == 0 and num_H == 0: # Carboxylic anhydride [C(CO)O(CO)C] + num_C = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*']: + num_C += 1 + if num_C == 2: + cnt = 0 + for neighbor in atom_neighbors: + for C_neighbor 
in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), C_neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3: + cnt += 1 + if cnt == 2: + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'carboxylic_anhydride') + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() == 'O': + C_neighbor.SetProp('FG', 'carboxylic_anhydride') + + elif atom_symbol == 'N': # and atom.GetProp('FG') == '': + num_C, num_O, num_N = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*']: + num_C += 1 + if neighbor.GetSymbol() == 'O': + num_O += 1 + if neighbor.GetSymbol() == 'N': + num_N += 1 + + #### Amines #### + if charge == 0 and num_H == 2 and atom_num_neighbors == 1 and atom.GetProp('FG') != 'hydrazone': # Primary amine [RNH2] + atom.SetProp('FG', 'primary_amine') + + if charge == 0 and num_H == 1 and atom_num_neighbors == 2: # Secondary amine [R'R"NH] + atom.SetProp('FG', 'secondary_amine') + + if charge == 0 and atom_num_neighbors == 3 and atom.GetProp('FG') != 'carbamate': + cnt = 0 + C_idx = [] + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*']: + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), C_neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0 and atom.GetProp('FG') != 'imide': + atom.SetProp('FG', 'amide') + neighbor.SetProp('FG', 'amide') + C_neighbor.SetProp('FG', 'amide') + cnt += 1 + C_idx.append(neighbor.GetIdx()) + + if cnt == 2: + for neighbor in atom_neighbors: + if neighbor.GetIdx() in C_idx: + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() in ['O', 'N' ]: + neighbor.SetProp('FG', 'imide') + C_neighbor.SetProp('FG', 'imide') + + if atom.GetProp('FG') not in ['imide', 'amide', 'amidine', 'carbamate']: # Tertiary amine 
[R3N] + atom.SetProp('FG', 'tertiary_amine') + + if charge == 1 and atom_num_neighbors == 4: # 4° ammonium ion [R3N] + atom.SetProp('FG', '4_ammonium_ion') + + if charge == 0 and num_C == 1 and num_N == 1 and num_H == 0 and atom_num_neighbors == 2: # Hydrazone [R'R"CN2H2] + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*'] and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0: + condition1 = True + if neighbor.GetSymbol() == 'N' and neighbor.GetTotalNumHs() == 2 and neighbor.GetFormalCharge() == 0: + condition2 = True + if condition1 and condition2: + atom.SetProp('FG', 'hydrazone') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'hydrazone') + + #### Imine #### + if charge == 0 and num_C == 1 and num_H == 1 and num_N == 0 and atom_num_neighbors == 1: # Primary ketimine [RC(=NH)R'] + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0: + atom.SetProp('FG', 'primary_ketimine') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'primary_ketimine') + + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + atom.SetProp('FG', 'primary_aldimine') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'primary_aldimine') + + if charge == 0 and atom_num_neighbors == 1 and atom.GetProp('FG') not in ['thiocyanate', 'cyanate']: # Nitrile + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.TRIPLE: + atom.SetProp('FG', 'nitrile') + + if charge == 0 and num_C >= 1 and atom_num_neighbors == 2 and atom.GetProp('FG') != 
'hydrazone': # Secondary ketimine [RC(=NR'')R'] + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*'] and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0: + atom.SetProp('FG', 'secondary_ketimine') + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + neighbor.SetProp('FG', 'secondary_ketimine') + + if neighbor.GetSymbol() in ['C', '*'] and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1: + atom.SetProp('FG', 'secondary_aldimine') + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + neighbor.SetProp('FG', 'secondary_aldimine') + + + if charge == 1 and num_N == 2 and atom_num_neighbors == 2: # Azide [RN3] + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetFormalCharge() == 0 and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + condition1 = True + if neighbor.GetFormalCharge() == -1 and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + condition2 = True + if condition1 and condition2 and not in_ring: + atom.SetProp('FG', 'azide') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'azide') + + if charge == 0 and num_N == 1 and atom_num_neighbors == 2 and not in_ring: # Azo [RN2R'] + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + atom.SetProp('FG', 'azo') + neighbor.SetProp('FG', 'azo') + break + + if charge == 1 and num_O == 3 and 
atom_num_neighbors == 3: # Nitrate [RONO2] + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition1 = True + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == -1: + condition2 = True + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0: + condition3 = True + + if condition1 and condition2 and condition3 and not in_ring: + atom.SetProp('FG', 'nitrate') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'nitrate') + + if charge == 1 and num_C >= 1 and atom_num_neighbors == 2: # Isonitrile + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*'] and neighbor.GetFormalCharge() == -1 and len(neighbor.GetNeighbors()) == 1: + atom.SetProp('FG', 'isonitrile') + neighbor.SetProp('FG', 'isonitrile') + + if charge == 0 and num_O == 2 and atom_num_neighbors == 2 and not in_ring: # Nitrite + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2: + atom.SetProp('FG', 'nitrosooxy') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'nitrosooxy') + + if charge == 1 and num_O == 2 and atom_num_neighbors == 3 and not in_ring: # Nitro compound + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition1 = True + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == -1: + condition2 = True + if condition1 and condition2 and not in_ring: + 
atom.SetProp('FG', 'nitro') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'nitro') + + if charge == 0 and num_O == 1 and atom_num_neighbors == 2 and not in_ring: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: # Nitroso compound + atom.SetProp('FG', 'nitroso') + neighbor.SetProp('FG', 'nitroso') + + if charge == 0 and num_O == 1 and num_C == 1 and atom_num_neighbors == 2: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1: + condition1 = True + if neighbor.GetSymbol() in ['C', '*'] and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + condition2 = True + if neighbor.GetSymbol() in ['C', '*'] and neighbor.GetTotalNumHs() == 0 and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 3: + condition3 = True + + if condition1 and condition2 and not in_ring: + atom.SetProp('FG', 'aldoxime') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'aldoxime') + if condition1 and condition3 and not in_ring: + atom.SetProp('FG', 'ketoxime') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'ketoxime') + + ########################### Groups containing sulfur ########################### + elif atom_symbol == 'S' and charge == 0: + num_C, num_S, num_O = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*']: + num_C += 1 + if neighbor.GetSymbol() == 'S': + num_S += 1 + if neighbor.GetSymbol() == 'O': + num_O += 1 + + if num_H == 1 and atom_num_neighbors == 1 and atom.GetProp('FG') not in ['carbothioic_S-acid', 'carbodithioic_acid']: + neighbor = atom_neighbors[0] + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == 
Chem.BondType.SINGLE: + atom.SetProp('FG', 'sulfhydryl') + + if num_H == 0 and atom_num_neighbors == 2 and atom.GetProp('FG') not in ['sulfhydrylester', 'carbodithio']: + cnt = 0 + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + cnt += 1 + if cnt == 2: + atom.SetProp('FG', 'sulfide') + + if num_H == 0 and num_S == 1 and atom_num_neighbors == 2: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2: + condition1 = True + if neighbor.GetSymbol() != 'S' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + condition2 = True + if condition1 and condition2: + atom.SetProp('FG', 'disulfide') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'S': + neighbor.SetProp('FG', 'disulfide') + + if num_H == 0 and num_O >= 1 and atom_num_neighbors == 3: + condition = False + cnt = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition = True + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + cnt += 1 + if condition and cnt == 2: + atom.SetProp('FG', 'sulfinyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'sulfinyl') + + if num_H == 0 and num_O >= 2 and atom_num_neighbors == 4: + cnt1 = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + cnt1 += 1 + if cnt1 == 2: + atom.SetProp('FG', 'sulfonyl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, 
neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + neighbor.SetProp('FG', 'sulfonyl') + + if num_H == 0 and num_O == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + condition2 = True + if neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + condition3 = True + if condition1 and condition2 and condition3 and not in_ring: + atom.SetProp('FG', 'sulfino') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'sulfino') + + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + cnt += 1 + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + condition1 = True + if neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + condition2 = True + if condition1 and condition2 and cnt == 2 and not in_ring: + atom.SetProp('FG', 'sulfonic_acid') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'sulfonic_acid') + + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, 
False + cnt = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + cnt += 1 + if neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 0 and neighbor.GetFormalCharge() == 0: + condition2 = True + if condition1 and condition2 and cnt == 2: + atom.SetProp('FG', 'sulfonate_ester') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'sulfonate_ester') + + if num_H == 0 and atom_num_neighbors == 2: + for neighbor in atom_neighbors: + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(C_neighbor.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.TRIPLE and not in_ring: + atom.SetProp('FG', 'thiocyanate') + neighbor.SetProp('FG', 'thiocyanate') + C_neighbor.SetProp('FG', 'thiocyanate') + + ########################### Groups containing phosphorus ########################### + elif atom_symbol == 'P' and not in_ring and charge == 0: + num_C, num_O = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*']: + num_C += 1 + if neighbor.GetSymbol() == 'O': + num_O += 1 + + if atom_num_neighbors == 3: + cnt = 0 + for neighbor in atom_neighbors: + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + cnt += 1 + if cnt == 3: + atom.SetProp('FG', 'phosphino') + + if num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and 
neighbor.GetFormalCharge() == 0: + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + cnt += 1 + if neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + condition2 = True + if condition1 and condition2 and cnt == 2: + atom.SetProp('FG', 'phosphono') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'phosphono') + + if num_O == 4 and atom_num_neighbors == 4: + condition1 = False + cnt1, cnt2 = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + condition1 = True + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + cnt1 += 1 + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 0 and neighbor.GetFormalCharge() == 0: + cnt2 += 1 + + if condition1 and cnt1 == 2 and cnt2 == 1: + atom.SetProp('FG', 'phosphate') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'phosphate') + if condition1 and cnt1 == 1 and cnt2 == 2: + atom.SetProp('FG', 'phosphodiester') + for neighbor in atom_neighbors: + neighbor.SetProp('FG', 'phosphodiester') + + if num_O == 1 and atom_num_neighbors == 4: + condition = False + cnt = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + condition = True + if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + 
cnt += 1 + if condition and cnt == 3: + atom.SetProp('FG', 'phosphoryl') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'phosphoryl') + + ########################### Groups containing boron ########################### + elif atom_symbol == 'B' and not in_ring and charge == 0: + num_C, num_O = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*']: + num_C += 1 + if neighbor.GetSymbol() == 'O': + num_O += 1 + + if num_O == 2 and atom_num_neighbors == 3: + cnt1, cnt2 = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + cnt1 += 1 + if neighbor.GetSymbol() == 'O' and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 2: + cnt2 += 1 + if cnt1 == 2: + atom.SetProp('FG', 'borono') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'borono') + if cnt2 == 2: + atom.SetProp('FG', 'boronate') + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + neighbor.SetProp('FG', 'boronate') + + if num_O == 1 and atom_num_neighbors == 3: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and neighbor.GetFormalCharge() == 0: + if neighbor.GetTotalNumHs() == 1: + atom.SetProp('FG', 'borino') + neighbor.SetProp('FG', 'borino') + if len(neighbor.GetNeighbors()) == 2: + atom.SetProp('FG', 'borinate') + neighbor.SetProp('FG', 'borinate') + + ########################### Groups containing silicon ########################### + elif atom_symbol =='Si' and not in_ring and charge == 0: + num_O, num_Cl, num_C = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O': + num_O += 1 + if neighbor.GetSymbol() == 'Cl': + num_Cl += 1 + if neighbor.GetSymbol() in ['C', '*']: + num_C += 1 + if num_O == 1 and charge == 0 and atom_num_neighbors == 4: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'O' and 
len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0: + atom.SetProp('FG', 'silyl_ether') + neighbor.SetProp('FG', 'silyl_ether') + if num_Cl == 2 and charge == 0 and atom_num_neighbors == 4: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == 'Cl' and neighbor.GetFormalCharge() == 0: + atom.SetProp('FG', 'dichlorosilane') + neighbor.SetProp('FG', 'dichlorosilane') + if num_C >= 3 and charge == 0 and atom_num_neighbors == 4 and atom.GetProp('FG') != 'silyl_ether': + cnt = 0 + C_idx = [] + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ['C', '*'] and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 3: + cnt += 1 + C_idx.append(neighbor.GetIdx()) + if cnt == 3: + atom.SetProp('FG', 'trimethylsilyl') + for idx in C_idx: + mol.GetAtomWithIdx(idx).SetProp('FG', 'trimethylsilyl') + + + ########################### Groups containing halogen ########################### + elif atom_symbol == 'F' and not in_ring and charge == 0 and atom.GetProp('FG') == '': + atom.SetProp('FG', 'fluoro') + elif atom_symbol == 'Cl' and not in_ring and charge == 0 and atom.GetProp('FG') == '': + atom.SetProp('FG', 'chloro') + elif atom_symbol == 'Br' and not in_ring and charge == 0 and atom.GetProp('FG') == '': + atom.SetProp('FG', 'bromo') + elif atom_symbol == 'I' and not in_ring and charge == 0 and atom.GetProp('FG') == '': + atom.SetProp('FG', 'iodo') + else: + pass + + ########################### Groups containing other elements ########################### + if atom.GetProp('FG') == '' and atom_symbol in ELEMENTS and not in_ring: + if charge == 0: + atom.SetProp('FG', atom_symbol) + else: + atom.SetProp('FG', f'{atom_symbol}[{charge}]') + else: + pass + + if atom_symbol == '*': + atom.SetProp('FG', '') + +test_case = { + 'hydroxyl': 'CCCCO', + 'ether': 'CCCCOC', + 'peroxy': 'CCCCOOCCCC', + 'hydroperoxy': 'CCCCCCCOO', + 'haloformyl': 'CCCCC(=O)F', + 'ketone': 'CCCC(=O)CCCC', + 'aldehyde': 'CCC(=O)', + 'carboxylate': 'CCCCC(=O)[O-]', + 
'carboxyl': 'CCCC(=O)O', + 'ester': 'CC(=O)OCCCCC', + 'hemiketal': 'CCCC(OC)(O)CCC', + 'ketal': 'CCCC(OCCC)(OCC)CCC', + 'carbonate_ester': 'C(=O)(OC(Cl)(Cl)Cl)OC(Cl)(Cl)Cl', + 'hemiacetal': 'CCCCC(OCCCC)(O)', + 'acetal': 'CCCCC(OCCC)(OCCC)', + 'orthoester': 'CC(OC)(OC)(OC)', + 'orthocarbonate_ester': 'C(OCCCC)(OCC)(OCC)(OCC)', + 'carboxylic_anhydride': 'C1CCC(CC1)C(=O)OC(=O)C2CCCCC2', + 'primary_amine': 'CCCCCCN', + 'secondary_amine': 'CCCCCCNCCC', + 'tertiary_amine': 'CCCCCCN(CCC)CCC', + '4_ammonium_ion': 'CCCCCC[N+](CC)(CCC)CCC', + 'hydrazone': 'CCCC(CCC)=NN', + 'primary_ketimine': 'CCCC(=N)CC', + 'secondary_ketimine': 'CCCC(=NCCC)CC', + 'primary_aldimine': 'CCCC(=N)', + 'secondary_aldimine': 'CCCC=NCCCC', + 'imide': 'CCC(=O)N(CCCC)C(=O)CCC', + 'amide': 'CCCC(=O)N(CCC)CCCCC', + 'amidine': 'CCCN=C(CC)N(CCCCC)CCC', + 'azide': 'C1=CC=C(C=C1)N=[N+]=[N-]', + 'azo': 'CN(C)C1=CC=C(C=C1)N=NC2=CC=C(C=C2)S(=O)(=O)[O-]', + 'cyanate': 'c1ccccc1COC#N', + 'isocyanate': 'CCCN=C=O', + 'nitrate': 'CCCCCO[N+](=O)[O-]', + 'nitrile': 'CCC#N', + 'isonitrile': 'CC[N+]#[C-]', + 'nitrosooxy': 'CC(C)CCON=O', + 'nitro': 'C[N+](=O)[O-]', + 'nitroso': 'C1=CC=C(C=C1)N=O', + 'aldoxime': 'CCCC=NO', + 'ketoxime': 'CCC(CCC)=NO', + 'carbamate': 'CC(C)OC(=O)N(CCC)C1=CC(=CC=C1)Cl', + 'sulfhydryl': 'CCCCCS', + 'sulfide': 'CSC', + 'disulfide': 'CSSC', + 'sulfinyl': 'CS(=O)C', + 'sulfonyl': 'CCCS(=O)(=O)CCCC', + 'sulfino': 'CCCCS(=O)O', + 'sulfonic_acid': 'CCCCS(=O)(=O)O', + 'sulfonate_ester': 'CCCS(=O)(=O)OCCCCC', + 'thiocyanate': 'CCCCSC#N', + 'isothiocyanate': 'c1ccccc1N=C=S', + 'thioketone': 'CCC(=S)CCCC', + 'thial': 'CCCC=S', + 'carbothioic_S-acid': 'CCC(=O)S', + 'carbothioic_O-acid': 'CCC(=S)O', + 'thiolester': 'CCC(=O)SCCC', + 'thionoester':'CCC(=S)OCCC', + 'carbodithioic_acid': 'CCCC(=S)S', + 'carbodithio': 'CCCC(=S)SCC', + 'phosphino': 'CCCCP(CCCC)CCCC', + 'phosphono': 'CCCP(=O)(O)O', + 'phosphate': 'CCCOP(=O)(O)O', + 'phosphodiester': 'CCCOP(=O)(O)OCCC', + 'phosphoryl': 'CCCP(=O)(CCC)CCC', + 
'borono': 'c1ccccc1B(O)O', + 'boronate': 'CCCB(OCC)OCCC', + 'borino': 'CCCB(CCCC)O', + 'borinate': 'CCCB(CCCC)OCCC', + 'silyl_ether': 'C[Si](C)(C)OS(=O)(=O)C(F)(F)F', + 'dichlorosilane': 'CCCC[Si](Cl)(Cl)CCCC', + 'trimethylsilyl': 'CCCC[Si](C)(C)C', + 'fluoro': 'CF', + 'chloro': 'CCCCl', + 'bromo': 'CBr', + 'iodo': 'CCCI', + 'trifluoromethyl': 'CCCC(F)(F)F', + 'difluorochloromethyl': 'CCC(F)(F)Cl', + 'bromodifluoromethyl': 'CCC(F)(F)Br', + 'trichloromethyl': 'CCC(Cl)(Cl)Cl', + 'bromodichloromethyl': 'CCC(Cl)(Cl)Br', + 'tribromomethyl': 'CCC(Br)(Br)Br', + 'dibromofluoromethyl': 'CCCC(F)(Br)Br', + 'triiodomethyl': 'CCC(I)(I)I', + 'difluoromethyl': 'CCC(F)F', + 'fluorochloromethyl': 'CCC(F)Cl', + 'dichloromethyl': 'CCCC(Cl)Cl', + 'chlorobromomethyl': 'CCCC(Cl)Br', + 'chloroiodomethyl': 'CCCC(Cl)I', + 'dibromomethyl': 'CCCCC(Br)Br', + 'bromoiodomethyl': 'CCCC(Br)I', + 'diiodomethyl': 'CCCCC(I)I' +} + +def has_ring(mol): + if mol.GetRingInfo().NumRings() > 0: + return True + else: + return False + +def ring_separation(mol): + AllChem.GetSymmSSSR(mol) # type: ignore + rings = mol.GetRingInfo().AtomRings() + + splitting_bonds = [] + if mol is not None: + for bond in mol.GetBonds(): + begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() + if (begin_atom.IsInRing() and not end_atom.IsInRing()) or (not begin_atom.IsInRing() and end_atom.IsInRing()) or (begin_atom.IsInRing() and end_atom.IsInRing()): + flag = True + for ring in rings: + if begin_atom.GetIdx() in ring and end_atom.GetIdx() in ring: + flag = False + break + if flag: + splitting_bonds.append(bond) + + if len(splitting_bonds) > 0: + fragments = FragmentOnBonds(mol, [bond.GetIdx() for bond in splitting_bonds], addDummies=True) + SMILES = m2s(fragments).split('.') + SMILES = [re.sub(r'\[\d+\*\]', '[*]', i) for i in SMILES] + SMILES = [m2s(s2m(i)) for i in SMILES] + return SMILES + else: + return None + +def set_atom_map_num(mol): + if mol is not None: + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + 
if idx != 0: + atom.SetAtomMapNum(idx) + else: + atom.SetAtomMapNum(-9) + +def find_neighbor_map(smiles): + matches = re.findall(r'\[(\d+)\*\]', smiles) + idx = [int(match) for match in matches] + if smiles.startswith('*'): + return set(idx) | {0} + else: + return set(idx) + +def find_atom_map(smiles): + matches = re.findall(r'\[[^\]]*:(\d+)\]', smiles) + idx = [int(match) for match in matches] + return set(idx) + +def get_scaffold(mol): + scaffold = MurckoScaffold.GetScaffoldForMol(mol) + AllChem.GetSymmSSSR(scaffold) # type: ignore + rings = scaffold.GetRingInfo().AtomRings() + + editable_mol = Chem.EditableMol(scaffold) # type: ignore + delete_idx = set() + + for bond in scaffold.GetBonds(): + begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() + if bond.GetBondType() == Chem.BondType.DOUBLE: # type: ignore + if begin_atom.IsInRing() and not end_atom.IsInRing(): + delete_idx.add(end_atom.GetIdx()) + if not begin_atom.IsInRing() and end_atom.IsInRing(): + delete_idx.add(begin_atom.GetIdx()) + if len(begin_atom.GetNeighbors()) == 1: + delete_idx.add(begin_atom.GetIdx()) + if len(end_atom.GetNeighbors()) == 1: + delete_idx.add(end_atom.GetIdx()) + + if scaffold is not None: + for atom in scaffold.GetAtoms(): + bonds = atom.GetBonds() + cnt = 0 + for bond in bonds: + if bond.GetBondType() == Chem.BondType.SINGLE: # type: ignore + cnt += 1 + if not atom.IsInRing() and cnt > 2: + flag = False + for bond in bonds: + begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() + if begin_atom.IsInRing(): + flag = True + for ring in rings: + if begin_atom.GetIdx() in ring: + delete_idx.add(begin_atom.GetIdx()) + if end_atom.IsInRing(): + flag = True + for ring in rings: + if end_atom.GetIdx() in ring: + for r in ring: + delete_idx.add(r) + if flag: + break + + + delete_idx = list(delete_idx) + delete_idx.sort(reverse=True) + for atom_idx in delete_idx: + editable_mol.RemoveAtom(atom_idx) + + return editable_mol.GetMol() + +def get_structure(mol): + 
set_atom_map_num(mol) + detect_functional_group(mol) + rings = mol.GetRingInfo().AtomRings() + + splitting_bonds = set() + for bond in mol.GetBonds(): + begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() + begin_atom_prop = begin_atom.GetProp('FG') + end_atom_prop = end_atom.GetProp('FG') + begin_atom_symbol = begin_atom.GetSymbol() + end_atom_symbol = end_atom.GetSymbol() + + if (begin_atom.IsInRing() and not end_atom.IsInRing()) or (not begin_atom.IsInRing() and end_atom.IsInRing()) or (begin_atom.IsInRing() and end_atom.IsInRing()): + flag = True + for ring in rings: + if begin_atom.GetIdx() in ring and end_atom.GetIdx() in ring: + flag = False + break + if flag: + splitting_bonds.add(bond) + else: + if begin_atom_prop != end_atom_prop: + splitting_bonds.add(bond) + if begin_atom_prop == '' and end_atom_prop == '': + if (begin_atom_symbol in ['C', '*'] and end_atom_symbol != 'C') or (begin_atom_symbol != 'C' and end_atom_symbol in ['C', '*']): + splitting_bonds.add(bond) + + splitting_bonds = list(splitting_bonds) + if splitting_bonds != []: + fragments = Chem.FragmentOnBonds(mol, [bond.GetIdx() for bond in splitting_bonds], addDummies=True) + BONDS = set() + for bond in splitting_bonds: + BONDS.add((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond.GetBondType())) + else: + fragments = mol + BONDS = set() + smiles = m2s(fragments).replace('-9', '0').split('.') + + structure = {} + for frag in smiles: + atom_idx, neighbor_idx = set(), set() + atom_idx = find_atom_map(frag) + neighbor_idx = find_neighbor_map(frag) + structure[frag] = {'atom': atom_idx, 'neighbor': neighbor_idx} + + return structure, BONDS + +def preprocess_smiles(smiles): + mol = s2m(smiles) + if mol is not None: + for atom in mol.GetAtoms(): + atom.SetAtomMapNum(0) + smiles = m2s(mol) + smiles = re.sub(r'\[\d+\*\]', '[*]', smiles) + smiles = m2s(s2m(smiles)) + return smiles + +def preprocess_mol(mol): + MOL = deepcopy(mol) + if mol is not None: + for atom in mol.GetAtoms(): + if 
atom.GetSymbol() == '*': + atom.SetAtomicNum(1) + mol_ = s2m(m2s(MOL)) + if mol_ is None: + print(m2s(mol)) + for atom in MOL.GetAtoms(): + if atom.GetSymbol() == '*': + atom.SetAtomicNum(6) + mol_ = s2m(m2s(MOL)) + else: + del MOL + return mol_ + +def remove_wildcards(mol): + editable_mol = Chem.EditableMol(mol) + wildcard_indices = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomicNum() == 0] + for idx in sorted(wildcard_indices, reverse=True): + editable_mol.RemoveAtom(idx) + return editable_mol.GetMol() + +def get_ring_structure(mol): + for atom in mol.GetAtoms(): + if atom.GetAtomicNum() != 6: + atom.SetAtomicNum(6) + for bond in mol.GetBonds(): + if bond.GetIsAromatic(): + bond.SetIsAromatic(False) + if bond.GetBondType() != Chem.BondType.SINGLE: + bond.SetBondType(Chem.BondType.SINGLE) + return mol + +def get_core_structure(mol): + for atom in mol.GetAtoms(): + if atom.GetAtomicNum() != 6: + atom.SetAtomicNum(6) + for bond in mol.GetBonds(): + if bond.GetBondType() != Chem.BondType.SINGLE: + bond.SetBondType(Chem.BondType.SINGLE) + return mol + +def get_new_smiles_rep(mol): + def replace_pattern(match): + number = int(match.group(1)) + return feature_idx.get(number, f"UNK") + + if mol is not None: + detect_functional_group(mol) + feature_idx = dict() + + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + if idx == 0: + idx = -9 + atom.SetAtomMapNum(idx) + + symbol = atom.GetSymbol() + if atom.GetIsAromatic(): + symbol = symbol.lower() + fg = atom.GetProp('FG') if atom.HasProp('FG') else '' + ring = atom.GetProp('RING') if atom.HasProp('RING') else '' + + if fg != '' and ring != '': + feature = symbol + '_' + fg + '_' + ring + elif fg != '' and ring == '': + feature = symbol + '_' + fg + elif fg == '' and ring != '': + feature = symbol + '_' + ring + else: + feature = symbol + feature_idx[idx] = ' ' + feature + ' ' + + smiles = m2s(mol) + feature_idx[0] = feature_idx[-9] + smiles = smiles.replace('-9', '0') + smiles = re.sub(r'\[.*?:(\d+)\]', 
replace_pattern, smiles) + smiles = re.sub(r'\s+', ' ', smiles) + smiles_list = [] + for t in smiles.split(' '): + if '_' not in t and len(t) > 1: + smiles_list.extend([char + ' ' for char in t]) + else: + smiles_list.append(t + ' ') + + new_smiles = ''.join(smiles_list).strip() + return new_smiles \ No newline at end of file From fb724892a0309c9a499e8fe36a9baebded5280f5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 15 Apr 2025 19:59:10 +0200 Subject: [PATCH 002/224] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 866c153..2ee15ba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,25 @@ repos: -#- repo: https://github.com/PyCQA/isort -# rev: "5.12.0" -# hooks: -# - id: isort - repo: https://github.com/psf/black - rev: "22.10.0" + rev: "24.2.0" hooks: - - id: black \ No newline at end of file + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace \ No newline at end of file From 85b9d8afbf070287b8707724646a5c57799f231e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 19 Apr 2025 16:13:00 +0200 Subject: [PATCH 003/224] Update .gitignore --- .gitignore | 168 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) diff --git a/.gitignore b/.gitignore index a09c56d..f9cb175 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,169 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# 
Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# configs/ # commented as new configs can be added as a part of a feature + /.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt From 4c6a1910e40e5dd6b4ae972b0cdedc81ffe700b8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 19 Apr 2025 16:15:20 +0200 Subject: [PATCH 004/224] remove redundant code functions --- .../preprocessing/fg_detection/rule_based.py | 2463 +++++++++++------ 1 file changed, 1561 insertions(+), 902 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/rule_based.py b/chebai_graph/preprocessing/fg_detection/rule_based.py index 4c7236a..3a1083a 100644 --- a/chebai_graph/preprocessing/fg_detection/rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/rule_based.py @@ -3,121 +3,111 @@ # arXiv preprint arXiv:2410.02082 (2024). 
import re + from rdkit import Chem -from rdkit.Chem import MolToSmiles as m2s -from rdkit.Chem import MolFromSmiles as s2m -from rdkit.Chem import FragmentOnBonds -from rdkit.Chem.Scaffolds import MurckoScaffold from rdkit.Chem import AllChem -from copy import deepcopy +from rdkit.Chem import MolToSmiles as m2s electronegativity = { - 'H': 2.2, - 'LI': 0.98, - 'BE': 1.57, - 'B': 2.04, - 'C': 2.55, - 'N': 3.04, - 'O': 3.44, - 'F': 3.98, - 'NA': 0.93, - 'MG': 1.31, - 'AL': 1.61, - 'SI': 1.9, - 'P': 2.19, - 'S': 2.58, - 'CL': 3.16, - 'K': 0.82, - 'CA': 1.0, - 'SC': 1.36, - 'TI': 1.54, - 'V': 1.63, - 'CR': 1.66, - 'MN': 1.55, - 'FE': 1.83, - 'CO': 1.88, - 'NI': 1.91, - 'CU': 1.9, - 'ZN': 1.65, - 'GA': 1.81, - 'GE': 2.01, - 'AS': 2.18, - 'SE': 2.55, - 'BR': 2.96, - 'RB': 0.82, - 'SR': 0.95, - 'Y': 1.22, - 'ZR': 1.33, - 'NB': 1.6, - 'MO': 2.16, - 'TC': 1.9, - 'RU': 2.2, - 'RH': 2.28, - 'PD': 2.2, - 'AG': 1.93, - 'CD': 1.69, - 'IN': 1.78, - 'SN': 1.96, - 'SB': 2.05, - 'TE': 2.1, - 'I': 2.66, - 'CS': 0.79, - 'BA': 0.89, - 'LA': 1.1, - 'CE': 1.12, - 'PR': 1.13, - 'ND': 1.14, - 'PM': 1.13, - 'SM': 1.17, - 'EU': 1.2, - 'GD': 1.2, - 'TB': 1.1, - 'DY': 1.22, - 'HO': 1.23, - 'ER': 1.24, - 'TM': 1.25, - 'YB': 1.1, - 'LU': 1.27, - 'HF': 1.3, - 'TA': 1.5, - 'W': 2.36, - 'RE': 1.9, - 'OS': 2.2, - 'IR': 2.2, - 'PT': 2.28, - 'AU': 2.54, - 'HG': 2.0, - 'TL': 1.62, - 'PB': 2.33, - 'BI': 2.02, - 'PO': 2.0, - 'AT': 2.2, - 'FR': 0.7, - 'RA': 0.9, - 'AC': 1.1, - 'TH': 1.3, - 'PA': 1.5, - 'U': 1.38, - 'NP': 1.36, - 'PU': 1.28, - 'AM': 1.3, - 'CM': 1.3, - 'BK': 1.3, - 'CF': 1.3, - 'ES': 1.3, - 'FM': 1.3, - 'MD': 1.3, - 'NO': 1.3, - 'LR': 1.3 + "H": 2.2, + "LI": 0.98, + "BE": 1.57, + "B": 2.04, + "C": 2.55, + "N": 3.04, + "O": 3.44, + "F": 3.98, + "NA": 0.93, + "MG": 1.31, + "AL": 1.61, + "SI": 1.9, + "P": 2.19, + "S": 2.58, + "CL": 3.16, + "K": 0.82, + "CA": 1.0, + "SC": 1.36, + "TI": 1.54, + "V": 1.63, + "CR": 1.66, + "MN": 1.55, + "FE": 1.83, + "CO": 1.88, + "NI": 1.91, + "CU": 1.9, + "ZN": 
1.65, + "GA": 1.81, + "GE": 2.01, + "AS": 2.18, + "SE": 2.55, + "BR": 2.96, + "RB": 0.82, + "SR": 0.95, + "Y": 1.22, + "ZR": 1.33, + "NB": 1.6, + "MO": 2.16, + "TC": 1.9, + "RU": 2.2, + "RH": 2.28, + "PD": 2.2, + "AG": 1.93, + "CD": 1.69, + "IN": 1.78, + "SN": 1.96, + "SB": 2.05, + "TE": 2.1, + "I": 2.66, + "CS": 0.79, + "BA": 0.89, + "LA": 1.1, + "CE": 1.12, + "PR": 1.13, + "ND": 1.14, + "PM": 1.13, + "SM": 1.17, + "EU": 1.2, + "GD": 1.2, + "TB": 1.1, + "DY": 1.22, + "HO": 1.23, + "ER": 1.24, + "TM": 1.25, + "YB": 1.1, + "LU": 1.27, + "HF": 1.3, + "TA": 1.5, + "W": 2.36, + "RE": 1.9, + "OS": 2.2, + "IR": 2.2, + "PT": 2.28, + "AU": 2.54, + "HG": 2.0, + "TL": 1.62, + "PB": 2.33, + "BI": 2.02, + "PO": 2.0, + "AT": 2.2, + "FR": 0.7, + "RA": 0.9, + "AC": 1.1, + "TH": 1.3, + "PA": 1.5, + "U": 1.38, + "NP": 1.36, + "PU": 1.28, + "AM": 1.3, + "CM": 1.3, + "BK": 1.3, + "CF": 1.3, + "ES": 1.3, + "FM": 1.3, + "MD": 1.3, + "NO": 1.3, + "LR": 1.3, } -def sdf2smiles(sdf_file): - SMILES = set() - supplier = Chem.SDMolSupplier(sdf_file) - for mol in supplier: - if mol is not None: - SMILES.add(Chem.MolToSmiles(mol)) - return SMILES def ring_size_processing(ring_size): if ring_size[0] > ring_size[-1]: @@ -125,6 +115,7 @@ def ring_size_processing(ring_size): else: return ring_size + # Function to find all rings connected to a given ring def find_connected_rings(ring, remaining_rings): connected_rings = [ring] @@ -139,23 +130,118 @@ def find_connected_rings(ring, remaining_rings): merged = True return connected_rings -def detect_functional_group(mol): # type: ignore - AllChem.GetSymmSSSR(mol) # type: ignore - ELEMENTS = set([ - 'Ac', 'Ag', 'Al', 'Am', 'As', 'At', 'Au', 'B', 'Ba', 'Be', 'Bi', 'Bk', 'Br', - 'Ca', 'Cd', 'Ce', 'Cf', 'Cl', 'Cm', 'Co', 'Cr', 'Cs', 'Cu', 'Dy', 'Er', - 'Es', 'Eu', 'F', 'Fe', 'Fm', 'Fr', 'Ga', 'Gd', 'Ge', 'He', 'Hf', 'Hg', - 'Ho', 'I', 'In', 'Ir', 'K', 'Kr', 'La', 'Li', 'Lr', 'Lu', 'Md', 'Mg', 'Mn', - 'Mo', 'N', 'Na', 'Nb', 'Nd', 'Ne', 'Ni', 'Np', 'O', 'Os', 
'P', 'Pa', 'Pb', - 'Pd', 'Pm', 'Po', 'Pr', 'Pt', 'Pu', 'Ra', 'Rb', 'Re', 'Rh', 'Rn', 'Ru', 'S', - 'Sb', 'Sc', 'Se', 'Si', 'Sm', 'Sn', 'Sr', 'Ta', 'Tb', 'Tc', 'Te', 'Th', 'Ti', - 'Tl', 'Tm', 'U', 'V', 'W', 'Xe', 'Y', 'Yb', 'Zn', 'Zr']) - + +def detect_functional_group(mol): # type: ignore + AllChem.GetSymmSSSR(mol) # type: ignore + ELEMENTS = set( + [ + "Ac", + "Ag", + "Al", + "Am", + "As", + "At", + "Au", + "B", + "Ba", + "Be", + "Bi", + "Bk", + "Br", + "Ca", + "Cd", + "Ce", + "Cf", + "Cl", + "Cm", + "Co", + "Cr", + "Cs", + "Cu", + "Dy", + "Er", + "Es", + "Eu", + "F", + "Fe", + "Fm", + "Fr", + "Ga", + "Gd", + "Ge", + "He", + "Hf", + "Hg", + "Ho", + "I", + "In", + "Ir", + "K", + "Kr", + "La", + "Li", + "Lr", + "Lu", + "Md", + "Mg", + "Mn", + "Mo", + "N", + "Na", + "Nb", + "Nd", + "Ne", + "Ni", + "Np", + "O", + "Os", + "P", + "Pa", + "Pb", + "Pd", + "Pm", + "Po", + "Pr", + "Pt", + "Pu", + "Ra", + "Rb", + "Re", + "Rh", + "Rn", + "Ru", + "S", + "Sb", + "Sc", + "Se", + "Si", + "Sm", + "Sn", + "Sr", + "Ta", + "Tb", + "Tc", + "Te", + "Th", + "Ti", + "Tl", + "Tm", + "U", + "V", + "W", + "Xe", + "Y", + "Yb", + "Zn", + "Zr", + ] + ) + if mol is not None: for atom in mol.GetAtoms(): - atom.SetProp('FG', '') - atom.SetProp('RING', '') - + atom.SetProp("FG", "") + atom.SetProp("RING", "") + ######## SET RING PROP ######## # Get ring information ring_info = mol.GetRingInfo() @@ -183,11 +269,11 @@ def detect_functional_group(mol): # type: ignore # Display the fused ring blocks and their ring sizes for i, block in enumerate(fused_ring_blocks): - rs = '-'.join(str(size) for size in ring_size_processing(ring_sizes[i])) + rs = "-".join(str(size) for size in ring_size_processing(ring_sizes[i])) for idx in block: atom = mol.GetAtomWithIdx(idx) - atom.SetProp('RING', rs) - + atom.SetProp("RING", rs) + ######## SET FUNCTIONAL GROUP PROP ######## for atom in mol.GetAtoms(): atom_symbol = atom.GetSymbol() @@ -197,318 +283,684 @@ def detect_functional_group(mol): # type: ignore in_ring = 
atom.IsInRing() atom_idx = atom.GetIdx() charge = atom.GetFormalCharge() - + ########################### Groups containing oxygen ########################### - if atom_symbol in ['C', '*'] and charge == 0: # and atom.GetProp('FG') == '': + if ( + atom_symbol in ["C", "*"] and charge == 0 + ): # and atom.GetProp('FG') == '': num_O, num_X, num_C, num_N, num_S = 0, 0, 0, 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['F', 'Cl', 'Br', 'I']: + if neighbor.GetSymbol() in ["F", "Cl", "Br", "I"]: num_X += 1 - if neighbor.GetSymbol() == 'O': + if neighbor.GetSymbol() == "O": num_O += 1 - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: num_C += 1 - if neighbor.GetSymbol() == 'N': + if neighbor.GetSymbol() == "N": num_N += 1 - if neighbor.GetSymbol() == 'S': + if neighbor.GetSymbol() == "S": num_S += 1 - - if num_H == 1 and atom_num_neighbors == 3 and charge == 0 and atom.GetProp('FG') == '': - atom.SetProp('FG', 'tertiary_carbon') - if atom_num_neighbors == 4 and charge == 0 and atom.GetProp('FG') == '': - atom.SetProp('FG', 'quaternary_carbon') - if num_H == 0 and atom_num_neighbors == 3 and charge == 0 and atom.GetProp('FG') == '' and not in_ring: - atom.SetProp('FG', 'alkene_carbon') - - if num_O == 1 and atom_symbol == 'C' and atom.GetProp('FG') not in ['hemiacetal', 'hemiketal', 'acetal', 'ketal', 'orthoester', 'orthocarbonate_ester', 'carbonate_ester']: - if num_N == 1: # Cyanate and Isocyanate + + if ( + num_H == 1 + and atom_num_neighbors == 3 + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "tertiary_carbon") + if atom_num_neighbors == 4 and charge == 0 and atom.GetProp("FG") == "": + atom.SetProp("FG", "quaternary_carbon") + if ( + num_H == 0 + and atom_num_neighbors == 3 + and charge == 0 + and atom.GetProp("FG") == "" + and not in_ring + ): + atom.SetProp("FG", "alkene_carbon") + + if ( + num_O == 1 + and atom_symbol == "C" + and atom.GetProp("FG") + not in [ + "hemiacetal", + 
"hemiketal", + "acetal", + "ketal", + "orthoester", + "orthocarbonate_ester", + "carbonate_ester", + ] + ): + if num_N == 1: # Cyanate and Isocyanate condition1, condition2 = False, False - condition3, condition4= False, False + condition3, condition4 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.TRIPLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): condition2 = True - if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition3 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): condition4 = True - - if condition1 and condition2 and not in_ring: # Cyanate - atom.SetProp('FG', 'cyanate') + + if condition1 and condition2 and not in_ring: # Cyanate + atom.SetProp("FG", "cyanate") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'cyanate') + neighbor.SetProp("FG", "cyanate") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': + if 
neighbor.GetSymbol() == "O": for C_neighbor in neighbor.GetNeighbors(): - if C_neighbor.GetSymbol() in ['C', '*'] and C_neighbor.GetIdx() != atom_idx: - C_neighbor.SetProp('FG', '') - - if condition3 and condition4 and not in_ring: # Isocyanate - atom.SetProp('FG', 'isocyanate') + if ( + C_neighbor.GetSymbol() in ["C", "*"] + and C_neighbor.GetIdx() != atom_idx + ): + C_neighbor.SetProp("FG", "") + + if condition3 and condition4 and not in_ring: # Isocyanate + atom.SetProp("FG", "isocyanate") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'isocyanate') + neighbor.SetProp("FG", "isocyanate") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': + if neighbor.GetSymbol() == "O": bond = mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()) bondtype = bond.GetBondType() - if bondtype == Chem.BondType.SINGLE: # and not neighbor.IsInRing(): # [C-O]: Alcohol (COH) or Ether [COC] or Hydroperoxy [C-O-O-H] or Peroxide [C-O-O-C] - if neighbor.GetTotalNumHs() == 1: # Alcohol [COH] - neighbor.SetProp('FG', 'hydroxyl') - else: + if ( + bondtype == Chem.BondType.SINGLE + ): # and not neighbor.IsInRing(): # [C-O]: Alcohol (COH) or Ether [COC] or Hydroperoxy [C-O-O-H] or Peroxide [C-O-O-C] + if neighbor.GetTotalNumHs() == 1: # Alcohol [COH] + neighbor.SetProp("FG", "hydroxyl") + else: for O_neighbor in neighbor.GetNeighbors(): # if not O_neighbor.IsInRing(): - if O_neighbor.GetIdx() != atom_idx and O_neighbor.GetSymbol() in ['C', '*'] and neighbor.GetProp('FG') == '': # Ether [COC] - neighbor.SetProp('FG', 'ether') - if O_neighbor.GetSymbol() == 'O': - if O_neighbor.GetTotalNumHs() == 1: # Hydroperoxy [C-O-O-H] - neighbor.SetProp('FG', 'hydroperoxy') - O_neighbor.SetProp('FG', 'hydroperoxy') + if ( + O_neighbor.GetIdx() != atom_idx + and O_neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetProp("FG") == "" + ): # Ether [COC] + neighbor.SetProp("FG", "ether") + if O_neighbor.GetSymbol() == "O": + if ( + O_neighbor.GetTotalNumHs() == 1 + ): # Hydroperoxy 
[C-O-O-H] + neighbor.SetProp("FG", "hydroperoxy") + O_neighbor.SetProp("FG", "hydroperoxy") else: - neighbor.SetProp('FG', 'peroxy') - O_neighbor.SetProp('FG', 'peroxy') - - if bondtype == Chem.BondType.DOUBLE: # [C=O]: Ketone [CC(=0)C] or Aldehyde [CC(=O)H] or Acyl halide [C(=O)X] - if num_X == 1 and not neighbor.IsInRing(): # Acyl halide [C(=O)X] - atom.SetProp('FG', 'haloformyl') + neighbor.SetProp("FG", "peroxy") + O_neighbor.SetProp("FG", "peroxy") + + if ( + bondtype == Chem.BondType.DOUBLE + ): # [C=O]: Ketone [CC(=0)C] or Aldehyde [CC(=O)H] or Acyl halide [C(=O)X] + if ( + num_X == 1 and not neighbor.IsInRing() + ): # Acyl halide [C(=O)X] + atom.SetProp("FG", "haloformyl") for neighbor_ in atom_neighbors: - if neighbor_.GetSymbol() in ['O', 'F', 'Cl', 'Br', 'I']: - neighbor_.SetProp('FG', 'haloformyl') - - if (num_C == 1 and num_H == 1) or num_H == 2 and not in_ring: # Aldehyde [C(=O)H] - atom.SetProp('FG', 'aldehyde') - neighbor.SetProp('FG', 'aldehyde') - - if atom_num_neighbors == 3 and atom.GetProp('FG') not in ['haloformyl', 'amide']: # Ketone [C(=0)C] - atom.SetProp('FG', 'ketone') + if neighbor_.GetSymbol() in [ + "O", + "F", + "Cl", + "Br", + "I", + ]: + neighbor_.SetProp("FG", "haloformyl") + + if ( + (num_C == 1 and num_H == 1) + or num_H == 2 + and not in_ring + ): # Aldehyde [C(=O)H] + atom.SetProp("FG", "aldehyde") + neighbor.SetProp("FG", "aldehyde") + + if atom_num_neighbors == 3 and atom.GetProp( + "FG" + ) not in [ + "haloformyl", + "amide", + ]: # Ketone [C(=0)C] + atom.SetProp("FG", "ketone") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and not neighbor.IsInRing(): - neighbor.SetProp('FG', 'ketone') - - if num_O == 2: # and atom.GetProp('FG') == '': + if ( + neighbor.GetSymbol() == "O" + and not neighbor.IsInRing() + ): + neighbor.SetProp("FG", "ketone") + + if num_O == 2: # and atom.GetProp('FG') == '': if atom_num_neighbors == 3: if num_H == 0: - condition1, condition2, condition3, condition4 = False, False, False, 
False + condition1, condition2, condition3, condition4 = ( + False, + False, + False, + False, + ) for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == -1 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + and not neighbor.IsInRing() + ): condition2 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): condition3 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0 and atom.GetProp('FG') != 'carbamate': + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + and atom.GetProp("FG") 
!= "carbamate" + ): condition4 = True if condition1 and condition2: - atom.SetProp('FG', 'carboxylate') + atom.SetProp("FG", "carboxylate") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'carboxylate') + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "carboxylate") if condition1 and condition3: - atom.SetProp('FG', 'carboxyl') + atom.SetProp("FG", "carboxyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'carboxyl') - if condition1 and condition4 and atom.GetProp('FG') not in ['carbamate', 'carbonate_ester']: - atom.SetProp('FG', 'ester') + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "carboxyl") + if ( + condition1 + and condition4 + and atom.GetProp("FG") + not in ["carbamate", "carbonate_ester"] + ): + atom.SetProp("FG", "ester") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'ester') + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "ester") for O_neighbor in neighbor.GetNeighbors(): - O_neighbor.SetProp('FG', 'ester') - + O_neighbor.SetProp("FG", "ester") + if num_H == 1 and not in_ring: condition1, condition2 = False, False cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), 
neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): condition2 = True cnt += 1 if condition1 and condition2: - atom.SetProp('FG', 'hemiacetal') + atom.SetProp("FG", "hemiacetal") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'hemiacetal') + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "hemiacetal") if cnt == 2: - atom.SetProp('FG', 'acetal') + atom.SetProp("FG", "acetal") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'acetal') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "acetal") + if atom_num_neighbors == 4 and not in_ring: condition1, condition2 = False, False cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): condition2 = True cnt += 1 if condition1 and condition2: - atom.SetProp('FG', 'hemiketal') + atom.SetProp("FG", "hemiketal") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 
'O': - neighbor.SetProp('FG', 'hemiketal') + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "hemiketal") if cnt == 2: - atom.SetProp('FG', 'ketal') + atom.SetProp("FG", "ketal") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'ketal') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "ketal") + if num_O == 3 and atom_num_neighbors == 4 and not in_ring: n_C = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): n_C += 1 if n_C == 3: - atom.SetProp('FG', 'orthoester') + atom.SetProp("FG", "orthoester") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'orthoester') - - if num_O == 3 and atom_num_neighbors == 3 and charge == 0 and not in_ring: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "orthoester") + + if ( + num_O == 3 + and atom_num_neighbors == 3 + and charge == 0 + and not in_ring + ): condition1 = False n_O = 0 for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + if ( + mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and 
neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): n_O += 1 if condition1 and n_O == 2: - atom.SetProp('FG', 'carbonate_ester') + atom.SetProp("FG", "carbonate_ester") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'carbonate_ester') + neighbor.SetProp("FG", "carbonate_ester") if num_O == 4 and not in_ring: n_C = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): n_C += 1 if n_C == 4: - atom.SetProp('FG', 'orthocarbonate_ester') + atom.SetProp("FG", "orthocarbonate_ester") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'orthocarbonate_ester') + neighbor.SetProp("FG", "orthocarbonate_ester") - ########################### Groups containing nitrogen ########################### + ########################### Groups containing nitrogen ########################### #### Amidine #### if num_N == 2 and atom_num_neighbors == 3: condition1, condition2 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): condition1 = True - if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and 
len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): condition2 = True if condition1 and condition2: - atom.SetProp('FG', 'amidine') + atom.SetProp("FG", "amidine") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'N': - neighbor.SetProp('FG', 'amidine') - + if neighbor.GetSymbol() == "N": + neighbor.SetProp("FG", "amidine") + if num_N == 1 and num_O == 2 and atom_num_neighbors == 3: condition1, condition2, condition3 = False, False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 2: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 2 + ): condition2 = True - if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 3 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and 
neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 3 + and not neighbor.IsInRing() + ): condition3 = True if condition1 and condition2 and condition3: - atom.SetProp('FG', 'carbamate') + atom.SetProp("FG", "carbamate") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'carbamate') - + neighbor.SetProp("FG", "carbamate") + if num_N == 1 and num_S == 1: condition1, condition2 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and not neighbor.IsInRing(): - condition1 = True - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): condition2 = True if condition1 and condition2: - atom.SetProp('FG', 'isothiocyanate') + atom.SetProp("FG", "isothiocyanate") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'isothiocyanate') - + neighbor.SetProp("FG", "isothiocyanate") + if num_S == 1 and atom_num_neighbors == 3: for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): - atom.SetProp('FG', 'thioketone') - neighbor.SetProp('FG', 
'thioketone') - + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + atom.SetProp("FG", "thioketone") + neighbor.SetProp("FG", "thioketone") + if num_S == 1 and num_H == 1 and atom_num_neighbors == 2: for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): - atom.SetProp('FG', 'thial') - neighbor.SetProp('FG', 'thial') - + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + atom.SetProp("FG", "thial") + neighbor.SetProp("FG", "thial") + if num_S == 1 and num_O == 1 and atom_num_neighbors == 3: condition1, condition2 = False, False condition3, condition4 = False, False condition5, condition6 = False, False condition7, condition8 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 1 and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "O" + and 
mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and not neighbor.IsInRing() + ): condition2 = True - - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and not neighbor.IsInRing(): + + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): condition3 = True - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetTotalNumHs() == 0 and not len(neighbor.GetNeighbors())==1: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and not len(neighbor.GetNeighbors()) == 1 + ): condition4 = True - - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetTotalNumHs() == 0 and not neighbor.IsInRing(): + + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): flag = True for bond in neighbor.GetBonds(): if bond.GetBondType() != Chem.BondType.SINGLE: flag = False if flag: condition5 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and not neighbor.IsInRing() + ): condition6 = True - - 
if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0 and not neighbor.IsInRing(): + + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): condition7 = True - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetTotalNumHs() == 0 and len(neighbor.GetNeighbors())==1 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): condition8 = True if condition1 and condition2: - atom.SetProp('FG', 'carbothioic_S-acid') + atom.SetProp("FG", "carbothioic_S-acid") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['S', 'O']: - neighbor.SetProp('FG', 'carbothioic_S-acid') + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "carbothioic_S-acid") if condition3 and condition4: - atom.SetProp('FG', 'carbothioic_O-acid') + atom.SetProp("FG", "carbothioic_O-acid") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['S', 'O']: - neighbor.SetProp('FG', 'carbothioic_O-acid') + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "carbothioic_O-acid") if condition5 and condition6: - atom.SetProp('FG', 'thiolester') + atom.SetProp("FG", "thiolester") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['S', 'O']: - neighbor.SetProp('FG', 'thiolester') + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "thiolester") if condition7 and condition8: - atom.SetProp('FG', 
'thionoester') + atom.SetProp("FG", "thionoester") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['S', 'O']: - neighbor.SetProp('FG', 'thionoester') - + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "thionoester") if num_S == 2 and atom_num_neighbors == 3: condition1, condition2, condition3 = False, False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and len(neighbor.GetNeighbors()) == 1 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): condition1 = True - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetTotalNumHs() == 0 and len(neighbor.GetNeighbors()) == 1 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): condition2 = True - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), atom_idx).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 0 and len(neighbor.GetNeighbors()) == 2 and not neighbor.IsInRing(): + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 2 + and not neighbor.IsInRing() + ): flag = True for bond in neighbor.GetBonds(): if bond.GetBondType() != Chem.BondType.SINGLE: @@ -517,128 +969,137 @@ def 
detect_functional_group(mol): # type: ignore condition3 = True if condition1 and condition2: - atom.SetProp('FG', 'carbodithioic_acid') + atom.SetProp("FG", "carbodithioic_acid") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S': - neighbor.SetProp('FG', 'carbodithioic_acid') + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "carbodithioic_acid") if condition3 and condition2: - atom.SetProp('FG', 'carbodithio') + atom.SetProp("FG", "carbodithio") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S': - neighbor.SetProp('FG', 'carbodithio') + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "carbodithio") if num_X == 3 and charge == 0 and atom_num_neighbors == 4: num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'F': + if neighbor.GetSymbol() == "F": num_F += 1 - if neighbor.GetSymbol() == 'Cl': + if neighbor.GetSymbol() == "Cl": num_Cl += 1 - if neighbor.GetSymbol() == 'Br': + if neighbor.GetSymbol() == "Br": num_Br += 1 - if neighbor.GetSymbol() == 'I': + if neighbor.GetSymbol() == "I": num_I += 1 if num_F == 3: - atom.SetProp('FG', 'trifluoromethyl') + atom.SetProp("FG", "trifluoromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'F': - neighbor.SetProp('FG', 'trifluoromethyl') + if neighbor.GetSymbol() == "F": + neighbor.SetProp("FG", "trifluoromethyl") if num_F == 2 and num_Cl == 1: - atom.SetProp('FG', 'difluorochloromethyl') + atom.SetProp("FG", "difluorochloromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['F', 'Cl']: - neighbor.SetProp('FG', 'difluorochloromethyl') + if neighbor.GetSymbol() in ["F", "Cl"]: + neighbor.SetProp("FG", "difluorochloromethyl") if num_F == 2 and num_Br == 1: - atom.SetProp('FG', 'bromodifluoromethyl') + atom.SetProp("FG", "bromodifluoromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['F', 'Br']: - neighbor.SetProp('FG', 'bromodifluoromethyl') + if neighbor.GetSymbol() in 
["F", "Br"]: + neighbor.SetProp("FG", "bromodifluoromethyl") if num_Cl == 3: - atom.SetProp('FG', 'trichloromethyl') + atom.SetProp("FG", "trichloromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'Cl': - neighbor.SetProp('FG', 'trichloromethyl') + if neighbor.GetSymbol() == "Cl": + neighbor.SetProp("FG", "trichloromethyl") if num_Cl == 2 and num_Br == 1: - atom.SetProp('FG', 'bromodichloromethyl') + atom.SetProp("FG", "bromodichloromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['Cl', 'Br']: - neighbor.SetProp('FG', 'bromodichloromethyl') - + if neighbor.GetSymbol() in ["Cl", "Br"]: + neighbor.SetProp("FG", "bromodichloromethyl") + if num_Br == 3: - atom.SetProp('FG', 'tribromomethyl') + atom.SetProp("FG", "tribromomethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'Br': - neighbor.SetProp('FG', 'tribromomethyl') + if neighbor.GetSymbol() == "Br": + neighbor.SetProp("FG", "tribromomethyl") if num_Br == 2 and num_F == 1: - atom.SetProp('FG', 'dibromofluoromethyl') + atom.SetProp("FG", "dibromofluoromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['F', 'Br']: - neighbor.SetProp('FG', 'dibromofluoromethyl') - + if neighbor.GetSymbol() in ["F", "Br"]: + neighbor.SetProp("FG", "dibromofluoromethyl") + if num_I == 3: - atom.SetProp('FG', 'triiodomethyl') + atom.SetProp("FG", "triiodomethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'I': - neighbor.SetProp('FG', 'triiodomethyl') - - if num_X == 2 and charge == 0 and atom_num_neighbors == 3 and num_H == 1: + if neighbor.GetSymbol() == "I": + neighbor.SetProp("FG", "triiodomethyl") + + if ( + num_X == 2 + and charge == 0 + and atom_num_neighbors == 3 + and num_H == 1 + ): num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'F': + if neighbor.GetSymbol() == "F": num_F += 1 - if neighbor.GetSymbol() == 'Cl': + if neighbor.GetSymbol() == "Cl": num_Cl += 1 - if 
neighbor.GetSymbol() == 'Br': + if neighbor.GetSymbol() == "Br": num_Br += 1 - if neighbor.GetSymbol() == 'I': + if neighbor.GetSymbol() == "I": num_I += 1 - + if num_F == 2: - atom.SetProp('FG', 'difluoromethyl') + atom.SetProp("FG", "difluoromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'F': - neighbor.SetProp('FG', 'difluoromethyl') + if neighbor.GetSymbol() == "F": + neighbor.SetProp("FG", "difluoromethyl") if num_F == 1 and num_Cl == 1: - atom.SetProp('FG', 'fluorochloromethyl') + atom.SetProp("FG", "fluorochloromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['F', 'Cl']: - neighbor.SetProp('FG', 'fluorochloromethyl') - + if neighbor.GetSymbol() in ["F", "Cl"]: + neighbor.SetProp("FG", "fluorochloromethyl") + if num_Cl == 2: - atom.SetProp('FG', 'dichloromethyl') + atom.SetProp("FG", "dichloromethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'Cl': - neighbor.SetProp('FG', 'dichloromethyl') + if neighbor.GetSymbol() == "Cl": + neighbor.SetProp("FG", "dichloromethyl") if num_Cl == 1 and num_Br == 1: - atom.SetProp('FG', 'chlorobromomethyl') + atom.SetProp("FG", "chlorobromomethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['Cl', 'Br']: - neighbor.SetProp('FG', 'chlorobromomethyl') + if neighbor.GetSymbol() in ["Cl", "Br"]: + neighbor.SetProp("FG", "chlorobromomethyl") if num_Cl == 1 and num_I == 1: - atom.SetProp('FG', 'chloroiodomethyl') + atom.SetProp("FG", "chloroiodomethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['Cl', 'I']: - neighbor.SetProp('FG', 'chloroiodomethyl') - + if neighbor.GetSymbol() in ["Cl", "I"]: + neighbor.SetProp("FG", "chloroiodomethyl") + if num_Br == 2: - atom.SetProp('FG', 'dibromomethyl') + atom.SetProp("FG", "dibromomethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'Br': - neighbor.SetProp('FG', 'dibromomethyl') + if neighbor.GetSymbol() == "Br": + neighbor.SetProp("FG", "dibromomethyl") if num_Br == 1 and 
num_I == 1: - atom.SetProp('FG', 'bromoiodomethyl') + atom.SetProp("FG", "bromoiodomethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['Br', 'I']: - neighbor.SetProp('FG', 'bromoiodomethyl') - + if neighbor.GetSymbol() in ["Br", "I"]: + neighbor.SetProp("FG", "bromoiodomethyl") + if num_I == 2: - atom.SetProp('FG', 'diiodomethyl') + atom.SetProp("FG", "diiodomethyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'I': - neighbor.SetProp('FG', 'diiodomethyl') - - if (atom_num_neighbors == 2 or atom_num_neighbors == 1) and not in_ring and atom.GetProp('FG') == '': + if neighbor.GetSymbol() == "I": + neighbor.SetProp("FG", "diiodomethyl") + + if ( + (atom_num_neighbors == 2 or atom_num_neighbors == 1) + and not in_ring + and atom.GetProp("FG") == "" + ): bonds = atom.GetBonds() ns, nd, nt = 0, 0, 0 for bond in bonds: @@ -649,57 +1110,86 @@ def detect_functional_group(mol): # type: ignore else: nt += 1 if ns >= 1 and nd == 0 and nt == 0: - atom.SetProp('FG', 'alkyl') + atom.SetProp("FG", "alkyl") if nd >= 1: - atom.SetProp('FG', 'alkene') + atom.SetProp("FG", "alkene") if nt == 1: - atom.SetProp('FG', 'alkyne') - - elif atom_symbol == 'O' and not in_ring and charge == 0 and num_H == 0: # Carboxylic anhydride [C(CO)O(CO)C] + atom.SetProp("FG", "alkyne") + + elif ( + atom_symbol == "O" and not in_ring and charge == 0 and num_H == 0 + ): # Carboxylic anhydride [C(CO)O(CO)C] num_C = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: num_C += 1 if num_C == 2: cnt = 0 for neighbor in atom_neighbors: for C_neighbor in neighbor.GetNeighbors(): - if C_neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), C_neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3: + if ( + C_neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), C_neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and 
len(neighbor.GetNeighbors()) == 3 + ): cnt += 1 if cnt == 2: for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'carboxylic_anhydride') + neighbor.SetProp("FG", "carboxylic_anhydride") for C_neighbor in neighbor.GetNeighbors(): - if C_neighbor.GetSymbol() == 'O': - C_neighbor.SetProp('FG', 'carboxylic_anhydride') + if C_neighbor.GetSymbol() == "O": + C_neighbor.SetProp("FG", "carboxylic_anhydride") - elif atom_symbol == 'N': # and atom.GetProp('FG') == '': + elif atom_symbol == "N": # and atom.GetProp('FG') == '': num_C, num_O, num_N = 0, 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: num_C += 1 - if neighbor.GetSymbol() == 'O': + if neighbor.GetSymbol() == "O": num_O += 1 - if neighbor.GetSymbol() == 'N': + if neighbor.GetSymbol() == "N": num_N += 1 - - #### Amines #### - if charge == 0 and num_H == 2 and atom_num_neighbors == 1 and atom.GetProp('FG') != 'hydrazone': # Primary amine [RNH2] - atom.SetProp('FG', 'primary_amine') - if charge == 0 and num_H == 1 and atom_num_neighbors == 2: # Secondary amine [R'R"NH] - atom.SetProp('FG', 'secondary_amine') - - if charge == 0 and atom_num_neighbors == 3 and atom.GetProp('FG') != 'carbamate': + #### Amines #### + if ( + charge == 0 + and num_H == 2 + and atom_num_neighbors == 1 + and atom.GetProp("FG") != "hydrazone" + ): # Primary amine [RNH2] + atom.SetProp("FG", "primary_amine") + + if ( + charge == 0 and num_H == 1 and atom_num_neighbors == 2 + ): # Secondary amine [R'R"NH] + atom.SetProp("FG", "secondary_amine") + + if ( + charge == 0 + and atom_num_neighbors == 3 + and atom.GetProp("FG") != "carbamate" + ): cnt = 0 C_idx = [] for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: for C_neighbor in neighbor.GetNeighbors(): - if C_neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(neighbor.GetIdx(), C_neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and 
len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0 and atom.GetProp('FG') != 'imide': - atom.SetProp('FG', 'amide') - neighbor.SetProp('FG', 'amide') - C_neighbor.SetProp('FG', 'amide') + if ( + C_neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), C_neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + and atom.GetProp("FG") != "imide" + ): + atom.SetProp("FG", "amide") + neighbor.SetProp("FG", "amide") + C_neighbor.SetProp("FG", "amide") cnt += 1 C_idx.append(neighbor.GetIdx()) @@ -707,553 +1197,852 @@ def detect_functional_group(mol): # type: ignore for neighbor in atom_neighbors: if neighbor.GetIdx() in C_idx: for C_neighbor in neighbor.GetNeighbors(): - if C_neighbor.GetSymbol() in ['O', 'N' ]: - neighbor.SetProp('FG', 'imide') - C_neighbor.SetProp('FG', 'imide') - - if atom.GetProp('FG') not in ['imide', 'amide', 'amidine', 'carbamate']: # Tertiary amine [R3N] - atom.SetProp('FG', 'tertiary_amine') - - if charge == 1 and atom_num_neighbors == 4: # 4° ammonium ion [R3N] - atom.SetProp('FG', '4_ammonium_ion') - - if charge == 0 and num_C == 1 and num_N == 1 and num_H == 0 and atom_num_neighbors == 2: # Hydrazone [R'R"CN2H2] + if C_neighbor.GetSymbol() in ["O", "N"]: + neighbor.SetProp("FG", "imide") + C_neighbor.SetProp("FG", "imide") + + if atom.GetProp("FG") not in [ + "imide", + "amide", + "amidine", + "carbamate", + ]: # Tertiary amine [R3N] + atom.SetProp("FG", "tertiary_amine") + + if charge == 1 and atom_num_neighbors == 4: # 4° ammonium ion [R3N] + atom.SetProp("FG", "4_ammonium_ion") + + if ( + charge == 0 + and num_C == 1 + and num_N == 1 + and num_H == 0 + and atom_num_neighbors == 2 + ): # Hydrazone [R'R"CN2H2] condition1, condition2 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*'] and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and 
len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if neighbor.GetSymbol() == 'N' and neighbor.GetTotalNumHs() == 2 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "N" + and neighbor.GetTotalNumHs() == 2 + and neighbor.GetFormalCharge() == 0 + ): condition2 = True if condition1 and condition2: - atom.SetProp('FG', 'hydrazone') + atom.SetProp("FG", "hydrazone") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'hydrazone') + neighbor.SetProp("FG", "hydrazone") #### Imine #### - if charge == 0 and num_C == 1 and num_H == 1 and num_N == 0 and atom_num_neighbors == 1: # Primary ketimine [RC(=NH)R'] + if ( + charge == 0 + and num_C == 1 + and num_H == 1 + and num_N == 0 + and atom_num_neighbors == 1 + ): # Primary ketimine [RC(=NH)R'] for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0: - atom.SetProp('FG', 'primary_ketimine') + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "primary_ketimine") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'primary_ketimine') - - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: - atom.SetProp('FG', 'primary_aldimine') + neighbor.SetProp("FG", "primary_ketimine") + + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + 
and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "primary_aldimine") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'primary_aldimine') - - if charge == 0 and atom_num_neighbors == 1 and atom.GetProp('FG') not in ['thiocyanate', 'cyanate']: # Nitrile - for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.TRIPLE: - atom.SetProp('FG', 'nitrile') + neighbor.SetProp("FG", "primary_aldimine") - if charge == 0 and num_C >= 1 and atom_num_neighbors == 2 and atom.GetProp('FG') != 'hydrazone': # Secondary ketimine [RC(=NR'')R'] + if ( + charge == 0 + and atom_num_neighbors == 1 + and atom.GetProp("FG") not in ["thiocyanate", "cyanate"] + ): # Nitrile + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + ): + atom.SetProp("FG", "nitrile") + + if ( + charge == 0 + and num_C >= 1 + and atom_num_neighbors == 2 + and atom.GetProp("FG") != "hydrazone" + ): # Secondary ketimine [RC(=NR'')R'] for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*'] and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 3 and neighbor.GetFormalCharge() == 0: - atom.SetProp('FG', 'secondary_ketimine') + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "secondary_ketimine") for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: - neighbor.SetProp('FG', 'secondary_ketimine') - - if neighbor.GetSymbol() in ['C', '*'] and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == 
Chem.BondType.DOUBLE and len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1: - atom.SetProp('FG', 'secondary_aldimine') + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "secondary_ketimine") + + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + ): + atom.SetProp("FG", "secondary_aldimine") for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: - neighbor.SetProp('FG', 'secondary_aldimine') - - - if charge == 1 and num_N == 2 and atom_num_neighbors == 2: # Azide [RN3] + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "secondary_aldimine") + + if ( + charge == 1 and num_N == 2 and atom_num_neighbors == 2 + ): # Azide [RN3] condition1, condition2 = False, False for neighbor in atom_neighbors: - if neighbor.GetFormalCharge() == 0 and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + if ( + neighbor.GetFormalCharge() == 0 + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): condition1 = True - if neighbor.GetFormalCharge() == -1 and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + if ( + neighbor.GetFormalCharge() == -1 + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): condition2 = True if condition1 and condition2 and not in_ring: - atom.SetProp('FG', 'azide') + atom.SetProp("FG", "azide") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'azide') - - 
if charge == 0 and num_N == 1 and atom_num_neighbors == 2 and not in_ring: # Azo [RN2R'] + neighbor.SetProp("FG", "azide") + + if ( + charge == 0 + and num_N == 1 + and atom_num_neighbors == 2 + and not in_ring + ): # Azo [RN2R'] for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: - atom.SetProp('FG', 'azo') - neighbor.SetProp('FG', 'azo') + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "azo") + neighbor.SetProp("FG", "azo") break - if charge == 1 and num_O == 3 and atom_num_neighbors == 3: # Nitrate [RONO2] + if ( + charge == 1 and num_O == 3 and atom_num_neighbors == 3 + ): # Nitrate [RONO2] condition1, condition2, condition3 = False, False, False for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == -1: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + ): condition2 = True - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + ): condition3 = True - + if condition1 and condition2 and condition3 and not in_ring: - atom.SetProp('FG', 'nitrate') + 
atom.SetProp("FG", "nitrate") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'nitrate') - - if charge == 1 and num_C >= 1 and atom_num_neighbors == 2: # Isonitrile - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*'] and neighbor.GetFormalCharge() == -1 and len(neighbor.GetNeighbors()) == 1: - atom.SetProp('FG', 'isonitrile') - neighbor.SetProp('FG', 'isonitrile') + neighbor.SetProp("FG", "nitrate") - if charge == 0 and num_O == 2 and atom_num_neighbors == 2 and not in_ring: # Nitrite + if charge == 1 and num_C >= 1 and atom_num_neighbors == 2: # Isonitrile + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetFormalCharge() == -1 + and len(neighbor.GetNeighbors()) == 1 + ): + atom.SetProp("FG", "isonitrile") + neighbor.SetProp("FG", "isonitrile") + + if ( + charge == 0 + and num_O == 2 + and atom_num_neighbors == 2 + and not in_ring + ): # Nitrite for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2: - atom.SetProp('FG', 'nitrosooxy') + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + ): + atom.SetProp("FG", "nitrosooxy") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'nitrosooxy') - - if charge == 1 and num_O == 2 and atom_num_neighbors == 3 and not in_ring: # Nitro compound + neighbor.SetProp("FG", "nitrosooxy") + + if ( + charge == 1 + and num_O == 2 + and atom_num_neighbors == 3 + and not in_ring + ): # Nitro compound condition1, condition2 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if neighbor.GetSymbol() == "O": + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == 
Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == -1: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + ): condition2 = True if condition1 and condition2 and not in_ring: - atom.SetProp('FG', 'nitro') + atom.SetProp("FG", "nitro") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'nitro') - - if charge == 0 and num_O == 1 and atom_num_neighbors == 2 and not in_ring: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "nitro") + + if ( + charge == 0 + and num_O == 1 + and atom_num_neighbors == 2 + and not in_ring + ): for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: # Nitroso compound - atom.SetProp('FG', 'nitroso') - neighbor.SetProp('FG', 'nitroso') - - if charge == 0 and num_O == 1 and num_C == 1 and atom_num_neighbors == 2: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): # Nitroso compound + atom.SetProp("FG", "nitroso") + neighbor.SetProp("FG", "nitroso") + + if ( + charge == 0 + and num_O == 1 + and num_C == 1 + and atom_num_neighbors == 2 + ): condition1, condition2, condition3 = False, False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + ): condition1 = True - if neighbor.GetSymbol() in ['C', '*'] and 
neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): condition2 = True - if neighbor.GetSymbol() in ['C', '*'] and neighbor.GetTotalNumHs() == 0 and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 3: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 3 + ): condition3 = True if condition1 and condition2 and not in_ring: - atom.SetProp('FG', 'aldoxime') + atom.SetProp("FG", "aldoxime") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'aldoxime') + neighbor.SetProp("FG", "aldoxime") if condition1 and condition3 and not in_ring: - atom.SetProp('FG', 'ketoxime') + atom.SetProp("FG", "ketoxime") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'ketoxime') + neighbor.SetProp("FG", "ketoxime") ########################### Groups containing sulfur ########################### - elif atom_symbol == 'S' and charge == 0: + elif atom_symbol == "S" and charge == 0: num_C, num_S, num_O = 0, 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: num_C += 1 - if neighbor.GetSymbol() == 'S': + if neighbor.GetSymbol() == "S": num_S += 1 - if neighbor.GetSymbol() == 'O': + if neighbor.GetSymbol() == "O": num_O += 1 - if num_H == 1 and atom_num_neighbors == 1 and atom.GetProp('FG') not in ['carbothioic_S-acid', 'carbodithioic_acid']: + if ( + num_H == 1 + and atom_num_neighbors == 1 + and atom.GetProp("FG") + not in ["carbothioic_S-acid", "carbodithioic_acid"] + ): neighbor = atom_neighbors[0] - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: - atom.SetProp('FG', 'sulfhydryl') - - if num_H == 0 and atom_num_neighbors == 2 and atom.GetProp('FG') not in ['sulfhydrylester', 'carbodithio']: + if ( + 
mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + atom.SetProp("FG", "sulfhydryl") + + if ( + num_H == 0 + and atom_num_neighbors == 2 + and atom.GetProp("FG") not in ["sulfhydrylester", "carbodithio"] + ): cnt = 0 for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): cnt += 1 if cnt == 2: - atom.SetProp('FG', 'sulfide') - + atom.SetProp("FG", "sulfide") + if num_H == 0 and num_S == 1 and atom_num_neighbors == 2: condition1, condition2 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and len(neighbor.GetNeighbors()) == 2: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + ): condition1 = True - if neighbor.GetSymbol() != 'S' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + neighbor.GetSymbol() != "S" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): condition2 = True if condition1 and condition2: - atom.SetProp('FG', 'disulfide') + atom.SetProp("FG", "disulfide") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'S': - neighbor.SetProp('FG', 'disulfide') - + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "disulfide") + if num_H == 0 and num_O >= 1 and atom_num_neighbors == 3: condition = False cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and 
mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition = True - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): cnt += 1 if condition and cnt == 2: - atom.SetProp('FG', 'sulfinyl') + atom.SetProp("FG", "sulfinyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'sulfinyl') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfinyl") + if num_H == 0 and num_O >= 2 and atom_num_neighbors == 4: cnt1 = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): cnt1 += 1 if cnt1 == 2: - atom.SetProp('FG', 'sulfonyl') + atom.SetProp("FG", "sulfonyl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: - neighbor.SetProp('FG', 'sulfonyl') - + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "sulfonyl") + if num_H == 0 and num_O == 2 and atom_num_neighbors == 3: condition1, condition2, condition3 = False, False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): 
condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): condition2 = True - if neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): condition3 = True if condition1 and condition2 and condition3 and not in_ring: - atom.SetProp('FG', 'sulfino') + atom.SetProp("FG", "sulfino") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'sulfino') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfino") + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: condition1, condition2 = False, False cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): cnt += 1 - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if 
neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): condition2 = True if condition1 and condition2 and cnt == 2 and not in_ring: - atom.SetProp('FG', 'sulfonic_acid') + atom.SetProp("FG", "sulfonic_acid") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'sulfonic_acid') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfonic_acid") + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: condition1, condition2 = False, False cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): cnt += 1 - if neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 0 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + ): condition2 = True if condition1 and condition2 and cnt == 2: - atom.SetProp('FG', 'sulfonate_ester') + atom.SetProp("FG", "sulfonate_ester") for neighbor in atom_neighbors: - if 
neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'sulfonate_ester') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfonate_ester") + if num_H == 0 and atom_num_neighbors == 2: for neighbor in atom_neighbors: for C_neighbor in neighbor.GetNeighbors(): - if C_neighbor.GetSymbol() == 'N' and mol.GetBondBetweenAtoms(C_neighbor.GetIdx(), neighbor.GetIdx()).GetBondType() == Chem.BondType.TRIPLE and not in_ring: - atom.SetProp('FG', 'thiocyanate') - neighbor.SetProp('FG', 'thiocyanate') - C_neighbor.SetProp('FG', 'thiocyanate') + if ( + C_neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + C_neighbor.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + and not in_ring + ): + atom.SetProp("FG", "thiocyanate") + neighbor.SetProp("FG", "thiocyanate") + C_neighbor.SetProp("FG", "thiocyanate") ########################### Groups containing phosphorus ########################### - elif atom_symbol == 'P' and not in_ring and charge == 0: + elif atom_symbol == "P" and not in_ring and charge == 0: num_C, num_O = 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: num_C += 1 - if neighbor.GetSymbol() == 'O': + if neighbor.GetSymbol() == "O": num_O += 1 if atom_num_neighbors == 3: cnt = 0 for neighbor in atom_neighbors: - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): cnt += 1 if cnt == 3: - atom.SetProp('FG', 'phosphino') - + atom.SetProp("FG", "phosphino") + if num_O == 3 and atom_num_neighbors == 4: condition1, condition2 = False, False cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + 
atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): cnt += 1 - if neighbor.GetSymbol() != 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): condition2 = True if condition1 and condition2 and cnt == 2: - atom.SetProp('FG', 'phosphono') + atom.SetProp("FG", "phosphono") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'phosphono') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "phosphono") + if num_O == 4 and atom_num_neighbors == 4: condition1 = False cnt1, cnt2 = 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): condition1 = True - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): cnt1 += 1 - if 
neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE and neighbor.GetTotalNumHs() == 0 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + ): cnt2 += 1 - + if condition1 and cnt1 == 2 and cnt2 == 1: - atom.SetProp('FG', 'phosphate') + atom.SetProp("FG", "phosphate") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'phosphate') + neighbor.SetProp("FG", "phosphate") if condition1 and cnt1 == 1 and cnt2 == 2: - atom.SetProp('FG', 'phosphodiester') + atom.SetProp("FG", "phosphodiester") for neighbor in atom_neighbors: - neighbor.SetProp('FG', 'phosphodiester') - + neighbor.SetProp("FG", "phosphodiester") + if num_O == 1 and atom_num_neighbors == 4: condition = False cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): condition = True - if mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() == Chem.BondType.SINGLE: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): cnt += 1 if condition and cnt == 3: - atom.SetProp('FG', 'phosphoryl') + atom.SetProp("FG", "phosphoryl") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'phosphoryl') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "phosphoryl") + ########################### Groups containing boron ########################### - elif atom_symbol == 'B' and not in_ring and charge == 0: + 
elif atom_symbol == "B" and not in_ring and charge == 0: num_C, num_O = 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: num_C += 1 - if neighbor.GetSymbol() == 'O': + if neighbor.GetSymbol() == "O": num_O += 1 - + if num_O == 2 and atom_num_neighbors == 3: cnt1, cnt2 = 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and neighbor.GetTotalNumHs() == 1 and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): cnt1 += 1 - if neighbor.GetSymbol() == 'O' and neighbor.GetFormalCharge() == 0 and len(neighbor.GetNeighbors()) == 2: + if ( + neighbor.GetSymbol() == "O" + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 2 + ): cnt2 += 1 if cnt1 == 2: - atom.SetProp('FG', 'borono') + atom.SetProp("FG", "borono") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'borono') + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "borono") if cnt2 == 2: - atom.SetProp('FG', 'boronate') + atom.SetProp("FG", "boronate") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': - neighbor.SetProp('FG', 'boronate') - + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "boronate") + if num_O == 1 and atom_num_neighbors == 3: for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and neighbor.GetFormalCharge() == 0: + if ( + neighbor.GetSymbol() == "O" + and neighbor.GetFormalCharge() == 0 + ): if neighbor.GetTotalNumHs() == 1: - atom.SetProp('FG', 'borino') - neighbor.SetProp('FG', 'borino') + atom.SetProp("FG", "borino") + neighbor.SetProp("FG", "borino") if len(neighbor.GetNeighbors()) == 2: - atom.SetProp('FG', 'borinate') - neighbor.SetProp('FG', 'borinate') - + atom.SetProp("FG", "borinate") + neighbor.SetProp("FG", "borinate") + ########################### Groups containing silicon 
########################### - elif atom_symbol =='Si' and not in_ring and charge == 0: + elif atom_symbol == "Si" and not in_ring and charge == 0: num_O, num_Cl, num_C = 0, 0, 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O': + if neighbor.GetSymbol() == "O": num_O += 1 - if neighbor.GetSymbol() == 'Cl': + if neighbor.GetSymbol() == "Cl": num_Cl += 1 - if neighbor.GetSymbol() in ['C', '*']: + if neighbor.GetSymbol() in ["C", "*"]: num_C += 1 if num_O == 1 and charge == 0 and atom_num_neighbors == 4: for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'O' and len(neighbor.GetNeighbors()) == 2 and neighbor.GetFormalCharge() == 0: - atom.SetProp('FG', 'silyl_ether') - neighbor.SetProp('FG', 'silyl_ether') + if ( + neighbor.GetSymbol() == "O" + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "silyl_ether") + neighbor.SetProp("FG", "silyl_ether") if num_Cl == 2 and charge == 0 and atom_num_neighbors == 4: for neighbor in atom_neighbors: - if neighbor.GetSymbol() == 'Cl' and neighbor.GetFormalCharge() == 0: - atom.SetProp('FG', 'dichlorosilane') - neighbor.SetProp('FG', 'dichlorosilane') - if num_C >= 3 and charge == 0 and atom_num_neighbors == 4 and atom.GetProp('FG') != 'silyl_ether': + if ( + neighbor.GetSymbol() == "Cl" + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "dichlorosilane") + neighbor.SetProp("FG", "dichlorosilane") + if ( + num_C >= 3 + and charge == 0 + and atom_num_neighbors == 4 + and atom.GetProp("FG") != "silyl_ether" + ): cnt = 0 C_idx = [] for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ['C', '*'] and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 3: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 3 + ): cnt += 1 C_idx.append(neighbor.GetIdx()) if cnt == 3: - atom.SetProp('FG', 'trimethylsilyl') + atom.SetProp("FG", "trimethylsilyl") for idx in C_idx: - 
mol.GetAtomWithIdx(idx).SetProp('FG', 'trimethylsilyl') - + mol.GetAtomWithIdx(idx).SetProp("FG", "trimethylsilyl") ########################### Groups containing halogen ########################### - elif atom_symbol == 'F' and not in_ring and charge == 0 and atom.GetProp('FG') == '': - atom.SetProp('FG', 'fluoro') - elif atom_symbol == 'Cl' and not in_ring and charge == 0 and atom.GetProp('FG') == '': - atom.SetProp('FG', 'chloro') - elif atom_symbol == 'Br' and not in_ring and charge == 0 and atom.GetProp('FG') == '': - atom.SetProp('FG', 'bromo') - elif atom_symbol == 'I' and not in_ring and charge == 0 and atom.GetProp('FG') == '': - atom.SetProp('FG', 'iodo') + elif ( + atom_symbol == "F" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "fluoro") + elif ( + atom_symbol == "Cl" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "chloro") + elif ( + atom_symbol == "Br" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "bromo") + elif ( + atom_symbol == "I" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "iodo") else: pass ########################### Groups containing other elements ########################### - if atom.GetProp('FG') == '' and atom_symbol in ELEMENTS and not in_ring: + if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: if charge == 0: - atom.SetProp('FG', atom_symbol) + atom.SetProp("FG", atom_symbol) else: - atom.SetProp('FG', f'{atom_symbol}[{charge}]') + atom.SetProp("FG", f"{atom_symbol}[{charge}]") else: pass - - if atom_symbol == '*': - atom.SetProp('FG', '') - -test_case = { - 'hydroxyl': 'CCCCO', - 'ether': 'CCCCOC', - 'peroxy': 'CCCCOOCCCC', - 'hydroperoxy': 'CCCCCCCOO', - 'haloformyl': 'CCCCC(=O)F', - 'ketone': 'CCCC(=O)CCCC', - 'aldehyde': 'CCC(=O)', - 'carboxylate': 'CCCCC(=O)[O-]', - 'carboxyl': 'CCCC(=O)O', - 'ester': 'CC(=O)OCCCCC', - 
'hemiketal': 'CCCC(OC)(O)CCC', - 'ketal': 'CCCC(OCCC)(OCC)CCC', - 'carbonate_ester': 'C(=O)(OC(Cl)(Cl)Cl)OC(Cl)(Cl)Cl', - 'hemiacetal': 'CCCCC(OCCCC)(O)', - 'acetal': 'CCCCC(OCCC)(OCCC)', - 'orthoester': 'CC(OC)(OC)(OC)', - 'orthocarbonate_ester': 'C(OCCCC)(OCC)(OCC)(OCC)', - 'carboxylic_anhydride': 'C1CCC(CC1)C(=O)OC(=O)C2CCCCC2', - 'primary_amine': 'CCCCCCN', - 'secondary_amine': 'CCCCCCNCCC', - 'tertiary_amine': 'CCCCCCN(CCC)CCC', - '4_ammonium_ion': 'CCCCCC[N+](CC)(CCC)CCC', - 'hydrazone': 'CCCC(CCC)=NN', - 'primary_ketimine': 'CCCC(=N)CC', - 'secondary_ketimine': 'CCCC(=NCCC)CC', - 'primary_aldimine': 'CCCC(=N)', - 'secondary_aldimine': 'CCCC=NCCCC', - 'imide': 'CCC(=O)N(CCCC)C(=O)CCC', - 'amide': 'CCCC(=O)N(CCC)CCCCC', - 'amidine': 'CCCN=C(CC)N(CCCCC)CCC', - 'azide': 'C1=CC=C(C=C1)N=[N+]=[N-]', - 'azo': 'CN(C)C1=CC=C(C=C1)N=NC2=CC=C(C=C2)S(=O)(=O)[O-]', - 'cyanate': 'c1ccccc1COC#N', - 'isocyanate': 'CCCN=C=O', - 'nitrate': 'CCCCCO[N+](=O)[O-]', - 'nitrile': 'CCC#N', - 'isonitrile': 'CC[N+]#[C-]', - 'nitrosooxy': 'CC(C)CCON=O', - 'nitro': 'C[N+](=O)[O-]', - 'nitroso': 'C1=CC=C(C=C1)N=O', - 'aldoxime': 'CCCC=NO', - 'ketoxime': 'CCC(CCC)=NO', - 'carbamate': 'CC(C)OC(=O)N(CCC)C1=CC(=CC=C1)Cl', - 'sulfhydryl': 'CCCCCS', - 'sulfide': 'CSC', - 'disulfide': 'CSSC', - 'sulfinyl': 'CS(=O)C', - 'sulfonyl': 'CCCS(=O)(=O)CCCC', - 'sulfino': 'CCCCS(=O)O', - 'sulfonic_acid': 'CCCCS(=O)(=O)O', - 'sulfonate_ester': 'CCCS(=O)(=O)OCCCCC', - 'thiocyanate': 'CCCCSC#N', - 'isothiocyanate': 'c1ccccc1N=C=S', - 'thioketone': 'CCC(=S)CCCC', - 'thial': 'CCCC=S', - 'carbothioic_S-acid': 'CCC(=O)S', - 'carbothioic_O-acid': 'CCC(=S)O', - 'thiolester': 'CCC(=O)SCCC', - 'thionoester':'CCC(=S)OCCC', - 'carbodithioic_acid': 'CCCC(=S)S', - 'carbodithio': 'CCCC(=S)SCC', - 'phosphino': 'CCCCP(CCCC)CCCC', - 'phosphono': 'CCCP(=O)(O)O', - 'phosphate': 'CCCOP(=O)(O)O', - 'phosphodiester': 'CCCOP(=O)(O)OCCC', - 'phosphoryl': 'CCCP(=O)(CCC)CCC', - 'borono': 'c1ccccc1B(O)O', - 'boronate': 
'CCCB(OCC)OCCC', - 'borino': 'CCCB(CCCC)O', - 'borinate': 'CCCB(CCCC)OCCC', - 'silyl_ether': 'C[Si](C)(C)OS(=O)(=O)C(F)(F)F', - 'dichlorosilane': 'CCCC[Si](Cl)(Cl)CCCC', - 'trimethylsilyl': 'CCCC[Si](C)(C)C', - 'fluoro': 'CF', - 'chloro': 'CCCCl', - 'bromo': 'CBr', - 'iodo': 'CCCI', - 'trifluoromethyl': 'CCCC(F)(F)F', - 'difluorochloromethyl': 'CCC(F)(F)Cl', - 'bromodifluoromethyl': 'CCC(F)(F)Br', - 'trichloromethyl': 'CCC(Cl)(Cl)Cl', - 'bromodichloromethyl': 'CCC(Cl)(Cl)Br', - 'tribromomethyl': 'CCC(Br)(Br)Br', - 'dibromofluoromethyl': 'CCCC(F)(Br)Br', - 'triiodomethyl': 'CCC(I)(I)I', - 'difluoromethyl': 'CCC(F)F', - 'fluorochloromethyl': 'CCC(F)Cl', - 'dichloromethyl': 'CCCC(Cl)Cl', - 'chlorobromomethyl': 'CCCC(Cl)Br', - 'chloroiodomethyl': 'CCCC(Cl)I', - 'dibromomethyl': 'CCCCC(Br)Br', - 'bromoiodomethyl': 'CCCC(Br)I', - 'diiodomethyl': 'CCCCC(I)I' -} -def has_ring(mol): - if mol.GetRingInfo().NumRings() > 0: - return True - else: - return False + if atom_symbol == "*": + atom.SetProp("FG", "") -def ring_separation(mol): - AllChem.GetSymmSSSR(mol) # type: ignore - rings = mol.GetRingInfo().AtomRings() - splitting_bonds = [] - if mol is not None: - for bond in mol.GetBonds(): - begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() - if (begin_atom.IsInRing() and not end_atom.IsInRing()) or (not begin_atom.IsInRing() and end_atom.IsInRing()) or (begin_atom.IsInRing() and end_atom.IsInRing()): - flag = True - for ring in rings: - if begin_atom.GetIdx() in ring and end_atom.GetIdx() in ring: - flag = False - break - if flag: - splitting_bonds.append(bond) - - if len(splitting_bonds) > 0: - fragments = FragmentOnBonds(mol, [bond.GetIdx() for bond in splitting_bonds], addDummies=True) - SMILES = m2s(fragments).split('.') - SMILES = [re.sub(r'\[\d+\*\]', '[*]', i) for i in SMILES] - SMILES = [m2s(s2m(i)) for i in SMILES] - return SMILES - else: - return None - def set_atom_map_num(mol): if mol is not None: for atom in mol.GetAtoms(): @@ -1262,87 +2051,39 @@ def 
set_atom_map_num(mol): atom.SetAtomMapNum(idx) else: atom.SetAtomMapNum(-9) - + + def find_neighbor_map(smiles): - matches = re.findall(r'\[(\d+)\*\]', smiles) + matches = re.findall(r"\[(\d+)\*\]", smiles) idx = [int(match) for match in matches] - if smiles.startswith('*'): + if smiles.startswith("*"): return set(idx) | {0} else: return set(idx) - + + def find_atom_map(smiles): - matches = re.findall(r'\[[^\]]*:(\d+)\]', smiles) + matches = re.findall(r"\[[^\]]*:(\d+)\]", smiles) idx = [int(match) for match in matches] return set(idx) -def get_scaffold(mol): - scaffold = MurckoScaffold.GetScaffoldForMol(mol) - AllChem.GetSymmSSSR(scaffold) # type: ignore - rings = scaffold.GetRingInfo().AtomRings() - - editable_mol = Chem.EditableMol(scaffold) # type: ignore - delete_idx = set() - for bond in scaffold.GetBonds(): - begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() - if bond.GetBondType() == Chem.BondType.DOUBLE: # type: ignore - if begin_atom.IsInRing() and not end_atom.IsInRing(): - delete_idx.add(end_atom.GetIdx()) - if not begin_atom.IsInRing() and end_atom.IsInRing(): - delete_idx.add(begin_atom.GetIdx()) - if len(begin_atom.GetNeighbors()) == 1: - delete_idx.add(begin_atom.GetIdx()) - if len(end_atom.GetNeighbors()) == 1: - delete_idx.add(end_atom.GetIdx()) - - if scaffold is not None: - for atom in scaffold.GetAtoms(): - bonds = atom.GetBonds() - cnt = 0 - for bond in bonds: - if bond.GetBondType() == Chem.BondType.SINGLE: # type: ignore - cnt += 1 - if not atom.IsInRing() and cnt > 2: - flag = False - for bond in bonds: - begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() - if begin_atom.IsInRing(): - flag = True - for ring in rings: - if begin_atom.GetIdx() in ring: - delete_idx.add(begin_atom.GetIdx()) - if end_atom.IsInRing(): - flag = True - for ring in rings: - if end_atom.GetIdx() in ring: - for r in ring: - delete_idx.add(r) - if flag: - break - - - delete_idx = list(delete_idx) - delete_idx.sort(reverse=True) - for atom_idx in 
delete_idx: - editable_mol.RemoveAtom(atom_idx) - - return editable_mol.GetMol() - def get_structure(mol): - set_atom_map_num(mol) - detect_functional_group(mol) rings = mol.GetRingInfo().AtomRings() splitting_bonds = set() for bond in mol.GetBonds(): begin_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom() - begin_atom_prop = begin_atom.GetProp('FG') - end_atom_prop = end_atom.GetProp('FG') + begin_atom_prop = begin_atom.GetProp("FG") + end_atom_prop = end_atom.GetProp("FG") begin_atom_symbol = begin_atom.GetSymbol() end_atom_symbol = end_atom.GetSymbol() - if (begin_atom.IsInRing() and not end_atom.IsInRing()) or (not begin_atom.IsInRing() and end_atom.IsInRing()) or (begin_atom.IsInRing() and end_atom.IsInRing()): + if ( + (begin_atom.IsInRing() and not end_atom.IsInRing()) + or (not begin_atom.IsInRing() and end_atom.IsInRing()) + or (begin_atom.IsInRing() and end_atom.IsInRing()) + ): flag = True for ring in rings: if begin_atom.GetIdx() in ring and end_atom.GetIdx() in ring: @@ -1353,126 +2094,44 @@ def get_structure(mol): else: if begin_atom_prop != end_atom_prop: splitting_bonds.add(bond) - if begin_atom_prop == '' and end_atom_prop == '': - if (begin_atom_symbol in ['C', '*'] and end_atom_symbol != 'C') or (begin_atom_symbol != 'C' and end_atom_symbol in ['C', '*']): + if begin_atom_prop == "" and end_atom_prop == "": + if (begin_atom_symbol in ["C", "*"] and end_atom_symbol != "C") or ( + begin_atom_symbol != "C" and end_atom_symbol in ["C", "*"] + ): splitting_bonds.add(bond) splitting_bonds = list(splitting_bonds) if splitting_bonds != []: - fragments = Chem.FragmentOnBonds(mol, [bond.GetIdx() for bond in splitting_bonds], addDummies=True) + fragments = Chem.FragmentOnBonds( + mol, [bond.GetIdx() for bond in splitting_bonds], addDummies=True + ) BONDS = set() for bond in splitting_bonds: - BONDS.add((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond.GetBondType())) + BONDS.add( + (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), bond.GetBondType()) + 
) else: fragments = mol BONDS = set() - smiles = m2s(fragments).replace('-9', '0').split('.') + smiles = m2s(fragments).replace("-9", "0").split(".") structure = {} for frag in smiles: atom_idx, neighbor_idx = set(), set() atom_idx = find_atom_map(frag) neighbor_idx = find_neighbor_map(frag) - structure[frag] = {'atom': atom_idx, 'neighbor': neighbor_idx} - - return structure, BONDS - -def preprocess_smiles(smiles): - mol = s2m(smiles) - if mol is not None: - for atom in mol.GetAtoms(): - atom.SetAtomMapNum(0) - smiles = m2s(mol) - smiles = re.sub(r'\[\d+\*\]', '[*]', smiles) - smiles = m2s(s2m(smiles)) - return smiles - -def preprocess_mol(mol): - MOL = deepcopy(mol) - if mol is not None: - for atom in mol.GetAtoms(): - if atom.GetSymbol() == '*': - atom.SetAtomicNum(1) - mol_ = s2m(m2s(MOL)) - if mol_ is None: - print(m2s(mol)) - for atom in MOL.GetAtoms(): - if atom.GetSymbol() == '*': - atom.SetAtomicNum(6) - mol_ = s2m(m2s(MOL)) - else: - del MOL - return mol_ - -def remove_wildcards(mol): - editable_mol = Chem.EditableMol(mol) - wildcard_indices = [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomicNum() == 0] - for idx in sorted(wildcard_indices, reverse=True): - editable_mol.RemoveAtom(idx) - return editable_mol.GetMol() - -def get_ring_structure(mol): - for atom in mol.GetAtoms(): - if atom.GetAtomicNum() != 6: - atom.SetAtomicNum(6) - for bond in mol.GetBonds(): - if bond.GetIsAromatic(): - bond.SetIsAromatic(False) - if bond.GetBondType() != Chem.BondType.SINGLE: - bond.SetBondType(Chem.BondType.SINGLE) - return mol - -def get_core_structure(mol): - for atom in mol.GetAtoms(): - if atom.GetAtomicNum() != 6: - atom.SetAtomicNum(6) - for bond in mol.GetBonds(): - if bond.GetBondType() != Chem.BondType.SINGLE: - bond.SetBondType(Chem.BondType.SINGLE) - return mol + structure[frag] = {"atom": atom_idx, "neighbor": neighbor_idx} -def get_new_smiles_rep(mol): - def replace_pattern(match): - number = int(match.group(1)) - return feature_idx.get(number, 
f"UNK") + return structure, BONDS - if mol is not None: - detect_functional_group(mol) - feature_idx = dict() - for atom in mol.GetAtoms(): - idx = atom.GetIdx() - if idx == 0: - idx = -9 - atom.SetAtomMapNum(idx) - - symbol = atom.GetSymbol() - if atom.GetIsAromatic(): - symbol = symbol.lower() - fg = atom.GetProp('FG') if atom.HasProp('FG') else '' - ring = atom.GetProp('RING') if atom.HasProp('RING') else '' - - if fg != '' and ring != '': - feature = symbol + '_' + fg + '_' + ring - elif fg != '' and ring == '': - feature = symbol + '_' + fg - elif fg == '' and ring != '': - feature = symbol + '_' + ring - else: - feature = symbol - feature_idx[idx] = ' ' + feature + ' ' - - smiles = m2s(mol) - feature_idx[0] = feature_idx[-9] - smiles = smiles.replace('-9', '0') - smiles = re.sub(r'\[.*?:(\d+)\]', replace_pattern, smiles) - smiles = re.sub(r'\s+', ' ', smiles) - smiles_list = [] - for t in smiles.split(' '): - if '_' not in t and len(t) > 1: - smiles_list.extend([char + ' ' for char in t]) - else: - smiles_list.append(t + ' ') +if __name__ == "__main__": + from rdkit.Chem import MolFromSmiles as s2m - new_smiles = ''.join(smiles_list).strip() - return new_smiles \ No newline at end of file + SMILES = ( + "CCOc1c(OC)cccc1[C@@H]1C(C(=O)OCCOC)=C(C)N=c2s/c(=C/c3cccc(OCC#N)c3)c(=O)n21" + ) + mol = s2m(SMILES) + # set_atom_map_num(mol) + get_structure(mol) + print(m2s(mol)) From 5b4db355770bde910ac911b957acbe13ea43f95e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 19 Apr 2025 16:16:01 +0200 Subject: [PATCH 005/224] reader: add fg augment reader --- chebai_graph/preprocessing/reader.py | 203 +++++++++++++++++++++++++-- 1 file changed, 192 insertions(+), 11 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index b814d53..8cbbab5 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -1,19 +1,24 @@ import importlib +import os +from typing import List, Mapping, Optional, 
Tuple -from torch_geometric.utils import from_networkx -from typing import Tuple, Mapping, Optional, List - -import importlib +import chebai.preprocessing.reader as dr import networkx as nx -import os -import torch -import rdkit.Chem as Chem import pysmiles as ps -import chebai.preprocessing.reader as dr -from chebai_graph.preprocessing.collate import GraphCollator -import chebai_graph.preprocessing.properties as properties +import rdkit.Chem as Chem +import torch +from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn +from rdkit.Chem import Mol from torch_geometric.data import Data as GeomData -from lightning_utilities.core.rank_zero import rank_zero_warn, rank_zero_info +from torch_geometric.utils import from_networkx + +import chebai_graph.preprocessing.properties as properties +from chebai_graph.preprocessing.collate import GraphCollator +from chebai_graph.preprocessing.fg_detection.rule_based import ( + detect_functional_group, + get_structure, + set_atom_map_num, +) class GraphPropertyReader(dr.ChemDataReader): @@ -133,3 +138,179 @@ def _read_data(self, raw_data) -> Optional[GeomData]: def collate(self, list_of_tuples): return self.collator(list_of_tuples) + + +class GraphFGAugmentorReader(dr.ChemDataReader): + COLLATOR = GraphCollator + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.failed_counter = 0 + self.mol_object_buffer = {} + self.node_level = {"atom_node": 1, "fg_node": 2, "graph_node": 3} + self.edge_level = { + "within_atoms": 1, + "within_fg": 2, + "atom_fg": 3, + "fg_graphNode": 4, + } + + @classmethod + def name(cls): + return "graph_fg_augmentor" + + def _read_data(self, raw_data): + mol = self._smiles_to_mol(raw_data) + if mol is None: + return None + + x = torch.zeros((mol.GetNumAtoms(), 0)) + + edge_attr = torch.zeros((mol.GetNumBonds(), 0)) + + edge_index = self._augment_graph(mol) + + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + + def 
_smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: + """Load smiles into rdkit, store object in buffer""" + if smiles in self.mol_object_buffer: + return self.mol_object_buffer[smiles] + + mol = Chem.MolFromSmiles(smiles) + if mol is None: + rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") + self.failed_counter += 1 + else: + try: + Chem.SanitizeMol(mol) + except Exception as e: + rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") + self.failed_counter += 1 + self.mol_object_buffer[smiles] = mol + return mol + + def _augment_graph(self, mol: Mol): + edge_index = torch.tensor( + [ + [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], + [bond.GetEndAtomIdx() for bond in mol.GetBonds()], + ] + ) + within_atoms_edge_index = torch.cat([edge_index, edge_index[[1, 0], :]], dim=1) + + num_of_nodes = mol.GetNumAtoms() + + set_atom_map_num(mol) + detect_functional_group(mol) + + node_features = [] + sorted_atoms = sorted( + list(mol.GetAtoms()), key=lambda atom: atom.GetAtomMapNum() + ) + for idx, atom in enumerate(sorted_atoms): + node_features.append( + [self.node_level["atom_node"], self._get_fg_index(atom)] + ) + + structure, bonds = get_structure(mol) + + if not structure: + raise ValueError("") + + # Preprocess the molecular structure to match feature dictionary keys + fg_to_atoms_edge_index = [[], []] + new_structure = {} + for idx, fg in enumerate(structure): + # new_sm = preprocess_smiles(sm) # Preprocess SMILES to match the feature dictionary + new_structure[num_of_nodes] = { + "atom": structure[fg]["atom"] # Get atom list for fragment + } + for atom in structure[fg]["atom"]: + fg_to_atoms_edge_index[0].extend([num_of_nodes, atom]) + fg_to_atoms_edge_index[1].extend([atom, num_of_nodes]) + + node_features.append( + [ + self.node_level["fg_node"], + self._get_fg_index(next(iter(structure[fg]["atom"][0]))), + ] + ) + + num_of_nodes += 1 + + within_fg_edge_index = [[], []] + for bond in bonds: + start_idx, end_idx = bond[:2] + 
for key, value in new_structure.items(): + if start_idx in value["atom"]: + source_fg = key + if end_idx in value["atom"]: + target_fg = key + within_fg_edge_index[0].extend([source_fg, target_fg]) + within_fg_edge_index[1].extend([target_fg, source_fg]) + + node_features.append( + [self.node_level["global_node"], self._get_token_index("graph_fg")] + ) + global_node_edge_index = [[], []] + for fg in new_structure.keys(): + global_node_edge_index[0].extend([num_of_nodes, fg]) + global_node_edge_index[1].extend([fg, num_of_nodes]) + + all_edges = torch.cat( + [ + within_atoms_edge_index, + torch.tensor(fg_to_atoms_edge_index, dtype=torch.long), + torch.tensor(within_fg_edge_index, dtype=torch.long), + torch.tensor(global_node_edge_index, dtype=torch.long), + ], + dim=1, + ) + + return all_edges + + def _get_fg_index(self, atom): + fg_group = atom.GetProp("FG") + if fg_group: + fg_index = self._get_token_index(fg_group) + return fg_index + else: + raise Exception("") + + def on_finish(self): + rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") + self.mol_object_buffer = {} + + def read_property( + self, smiles: str, property: properties.MolecularProperty + ) -> Optional[List]: + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + return property.get_property_value(mol) + + +if __name__ == "__main__": + import matplotlib + import matplotlib.pyplot as plt + + matplotlib.use("TkAgg") # or 'Qt5Agg' + import networkx as nx + from torch_geometric.utils import to_networkx + + gr = GraphFGAugmentorReader() + SMILES = "CC(=O)Oc1ccccc1C(O)=O" + data_obj = gr._read_data(SMILES) + # Convert GeomData to NetworkX graph + G = to_networkx(data_obj, to_undirected=True) # optional: directed=False + + # Plot it + plt.figure(figsize=(4, 4)) + nx.draw(G, with_labels=True, node_color="skyblue", node_size=700, edge_color="gray") + plt.title("Molecular Graph") + plt.show() From 4af345db13d599e778c093c927b24167b1ce5510 Mon Sep 17 00:00:00 2001 From: 
aditya0by0 Date: Sat, 19 Apr 2025 16:26:25 +0200 Subject: [PATCH 006/224] reader: add check for graph functional group --- chebai_graph/preprocessing/reader.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 8cbbab5..c943739 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -142,6 +142,13 @@ def collate(self, list_of_tuples): class GraphFGAugmentorReader(dr.ChemDataReader): COLLATOR = GraphCollator + NODE_LEVEL = {"atom_node": 1, "fg_node": 2, "graph_node": 3} + EDGE_LEVEL = { + "within_atoms": 1, + "within_fg": 2, + "atom_fg": 3, + "fg_graphNode": 4, + } def __init__( self, @@ -151,13 +158,12 @@ def __init__( super().__init__(*args, **kwargs) self.failed_counter = 0 self.mol_object_buffer = {} - self.node_level = {"atom_node": 1, "fg_node": 2, "graph_node": 3} - self.edge_level = { - "within_atoms": 1, - "within_fg": 2, - "atom_fg": 3, - "fg_graphNode": 4, - } + + if "graph_fg" not in self.cache: + raise KeyError( + f"Function group `graph_fg` doesn't exits in {self.token_path}. 
" + f"It should be manually added to token file (preferably at 0th index)" + ) @classmethod def name(cls): @@ -214,7 +220,7 @@ def _augment_graph(self, mol: Mol): ) for idx, atom in enumerate(sorted_atoms): node_features.append( - [self.node_level["atom_node"], self._get_fg_index(atom)] + [self.NODE_LEVEL["atom_node"], self._get_fg_index(atom)] ) structure, bonds = get_structure(mol) @@ -236,7 +242,7 @@ def _augment_graph(self, mol: Mol): node_features.append( [ - self.node_level["fg_node"], + self.NODE_LEVEL["fg_node"], self._get_fg_index(next(iter(structure[fg]["atom"][0]))), ] ) @@ -255,7 +261,7 @@ def _augment_graph(self, mol: Mol): within_fg_edge_index[1].extend([target_fg, source_fg]) node_features.append( - [self.node_level["global_node"], self._get_token_index("graph_fg")] + [self.NODE_LEVEL["global_node"], self._get_token_index("graph_fg")] ) global_node_edge_index = [[], []] for fg in new_structure.keys(): From 2381686830d05079b602bf04626cdc05bbd59966 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 19 Apr 2025 20:03:57 +0200 Subject: [PATCH 007/224] reader: add ring size to node feature --- chebai_graph/preprocessing/reader.py | 43 ++++++++++++---------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index c943739..d1c214f 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -218,9 +218,14 @@ def _augment_graph(self, mol: Mol): sorted_atoms = sorted( list(mol.GetAtoms()), key=lambda atom: atom.GetAtomMapNum() ) + for idx, atom in enumerate(sorted_atoms): node_features.append( - [self.NODE_LEVEL["atom_node"], self._get_fg_index(atom)] + [ + self.NODE_LEVEL["atom_node"], + self._get_fg_index(atom), + self._get_ring_size(atom), + ] ) structure, bonds = get_structure(mol) @@ -240,10 +245,12 @@ def _augment_graph(self, mol: Mol): fg_to_atoms_edge_index[0].extend([num_of_nodes, atom]) 
fg_to_atoms_edge_index[1].extend([atom, num_of_nodes]) + any_atom = next(iter(structure[fg]["atom"][0])) # any atom related to fg node_features.append( [ self.NODE_LEVEL["fg_node"], - self._get_fg_index(next(iter(structure[fg]["atom"][0]))), + self._get_fg_index(any_atom), + self._get_ring_size(any_atom), ] ) @@ -261,7 +268,7 @@ def _augment_graph(self, mol: Mol): within_fg_edge_index[1].extend([target_fg, source_fg]) node_features.append( - [self.NODE_LEVEL["global_node"], self._get_token_index("graph_fg")] + [self.NODE_LEVEL["global_node"], self._get_token_index("graph_fg"), 0] ) global_node_edge_index = [[], []] for fg in new_structure.keys(): @@ -288,6 +295,15 @@ def _get_fg_index(self, atom): else: raise Exception("") + def _get_ring_size(self, atom): + ring_size_str = atom.GetProp("RING") + if ring_size_str: + ring_sizes = list(map(int, ring_size_str.split("-"))) + # TODO: Decide ring size for atoms belongs to fused rings, rn only max ring size taken + return max(ring_sizes) + else: + return 0 + def on_finish(self): rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} @@ -299,24 +315,3 @@ def read_property( if mol is None: return None return property.get_property_value(mol) - - -if __name__ == "__main__": - import matplotlib - import matplotlib.pyplot as plt - - matplotlib.use("TkAgg") # or 'Qt5Agg' - import networkx as nx - from torch_geometric.utils import to_networkx - - gr = GraphFGAugmentorReader() - SMILES = "CC(=O)Oc1ccccc1C(O)=O" - data_obj = gr._read_data(SMILES) - # Convert GeomData to NetworkX graph - G = to_networkx(data_obj, to_undirected=True) # optional: directed=False - - # Plot it - plt.figure(figsize=(4, 4)) - nx.draw(G, with_labels=True, node_color="skyblue", node_size=700, edge_color="gray") - plt.title("Molecular Graph") - plt.show() From f596677a64b4d3c656e6340c2f0183f53baffd4b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 21 Apr 2025 16:46:45 +0200 Subject: [PATCH 008/224] update 
logic for node features to be processed by properties --- chebai_graph/preprocessing/properties.py | 167 ++++++++++++++++++++++- chebai_graph/preprocessing/reader.py | 119 ++++++++++------ 2 files changed, 237 insertions(+), 49 deletions(-) diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 95f85ab..15f2e61 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -1,16 +1,22 @@ import abc -from typing import Optional +from typing import Dict, Optional import numpy as np import rdkit.Chem as Chem from descriptastorus.descriptors import rdNormalizedDescriptors +from rdkit.Chem import Mol +from chebai_graph.preprocessing.fg_detection.fg_constants import ( + EDGE_LEVELS, + NODE_LEVEL, + WITHIN_ATOMS_EDGE, +) from chebai_graph.preprocessing.property_encoder import ( - PropertyEncoder, - IndexEncoder, - OneHotEncoder, AsIsEncoder, BoolEncoder, + IndexEncoder, + OneHotEncoder, + PropertyEncoder, ) @@ -36,7 +42,7 @@ def get_property_value(self, mol: Chem.rdchem.Mol): raise NotImplementedError -class AtomProperty(MolecularProperty): +class AtomProperty(MolecularProperty, abc): """Property of an atom.""" def get_property_value(self, mol: Chem.rdchem.Mol): @@ -46,6 +52,66 @@ def get_atom_value(self, atom: Chem.rdchem.Atom): return NotImplementedError +class AugmentedAtomProperty(MolecularProperty, abc): + MAIN_KEY = "nodes" + + def get_property_value(self, augmented_mol: Dict): + if self.MAIN_KEY not in augmented_mol: + raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + + missing_keys = {"atom_nodes", "fg_nodes", "graph_node"} - augmented_mol[ + self.MAIN_KEY + ].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") + + atom_molecule: Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] + if not isinstance(atom_molecule, Mol): + raise TypeError( + 
f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' + ) + + prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] + + fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] + graph_node = atom_molecule[self.MAIN_KEY]["graph_node"] + if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict ' + f"containing its properties" + ) + + # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order + # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights + # https://mail.python.org/pipermail/python-dev/2017-December/151283.html + prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes]) + prop_list.extend([self.get_atom_value(atom) for atom in graph_node]) + + return prop_list + + @abc.abstractmethod + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + pass + + def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + value = self._get_atom_prop_value(atom, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value + + @staticmethod + def _get_atom_prop_value(atom: Chem.rdchem.Atom | Dict, prop: str): + if isinstance(atom, Chem.rdchem.Atom): + return atom.GetProp(prop) + elif isinstance(atom, dict): + return atom[prop] + else: + raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.") + + class BondProperty(MolecularProperty): def get_property_value(self, mol: Chem.rdchem.Mol): return [self.get_bond_value(bond) for bond in mol.GetBonds()] @@ -54,6 +120,64 @@ def get_bond_value(self, bond: Chem.rdchem.Bond): return NotImplementedError +class AugmentedBondProperty(MolecularProperty, abc): + MAIN_KEY = "edges" + + def get_property_value(self, augmented_mol: Dict): + if self.MAIN_KEY not in augmented_mol: 
+ raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + + missing_keys = EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") + + atom_molecule: Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE] + if not isinstance(atom_molecule, Mol): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' + ) + + prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] + + fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] + graph_node = atom_molecule[self.MAIN_KEY]["graph_node"] + if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict ' + f"containing its properties" + ) + + # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order + # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights + # https://mail.python.org/pipermail/python-dev/2017-December/151283.html + prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes]) + prop_list.extend([self.get_atom_value(atom) for atom in graph_node]) + + return prop_list + + @abc.abstractmethod + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + pass + + def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + value = self._get_atom_prop_value(atom, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value + + @staticmethod + def _get_atom_prop_value(atom: Chem.rdchem.Atom | Dict, prop: str): + if isinstance(atom, Chem.rdchem.Atom): + return atom.GetProp(prop) + elif isinstance(atom, dict): + return atom[prop] + else: + raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.") + + class 
MoleculeProperty(MolecularProperty): """Global property of a molecule.""" @@ -114,6 +238,39 @@ def get_atom_value(self, atom: Chem.rdchem.Atom): return atom.GetIsAromatic() +class AtomNodeLevel(AugmentedAtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + return self._check_modify_atom_prop_value(atom, NODE_LEVEL) + + +class AtomFunctionalGroup(AugmentedAtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + return self._check_modify_atom_prop_value(atom, "FG") + + +class AtomRingSize(AugmentedAtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + return self._check_modify_atom_prop_value(atom, "RING") + + def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + ring_size_str = self._get_atom_prop_value(atom, prop) + if ring_size_str: + ring_sizes = list(map(int, ring_size_str.split("-"))) + # TODO: Decide ring size for atoms belongs to fused rings, rn only max ring size taken + return max(ring_sizes) + else: + return 0 + + class BondAromaticity(BondProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or BoolEncoder(self)) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index d1c214f..5e915f0 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -1,6 +1,5 @@ -import importlib import os -from typing import List, Mapping, Optional, Tuple +from typing import List, Optional import chebai.preprocessing.reader as dr import networkx as nx @@ -14,6 +13,17 @@ import chebai_graph.preprocessing.properties as 
properties from chebai_graph.preprocessing.collate import GraphCollator +from chebai_graph.preprocessing.fg_detection.fg_constants import ( + ATOM_FG_EDGE, + ATOM_NODE_LEVEL, + EDGE_LEVEL, + FG_GRAPHNODE_LEVEL, + FG_NODE_LEVEL, + GRAPH_NODE_LEVEL, + NODE_LEVEL, + WITHIN_ATOMS_EDGE, + WITHIN_FG_EDGE, +) from chebai_graph.preprocessing.fg_detection.rule_based import ( detect_functional_group, get_structure, @@ -142,13 +152,6 @@ def collate(self, list_of_tuples): class GraphFGAugmentorReader(dr.ChemDataReader): COLLATOR = GraphCollator - NODE_LEVEL = {"atom_node": 1, "fg_node": 2, "graph_node": 3} - EDGE_LEVEL = { - "within_atoms": 1, - "within_fg": 2, - "atom_fg": 3, - "fg_graphNode": 4, - } def __init__( self, @@ -170,23 +173,28 @@ def name(cls): return "graph_fg_augmentor" def _read_data(self, raw_data): - mol = self._smiles_to_mol(raw_data) + augmented_mol, edge_index = self._get_augmented_molecule(raw_data) + + x = torch.zeros((augmented_mol["nodes"]["num_nodes"], 0)) + edge_attr = torch.zeros((augmented_mol["nodes"]["num_edges"], 0)) + + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + + def _get_augmented_molecule(self, smiles): + mol = self._smiles_to_mol(smiles) if mol is None: return None - x = torch.zeros((mol.GetNumAtoms(), 0)) + edge_index, augmented_graph_nodes, augmented_graph_edges = self._augment_graph( + mol + ) - edge_attr = torch.zeros((mol.GetNumBonds(), 0)) + augmented_mol = {"nodes": augmented_graph_nodes, "edges": augmented_graph_edges} + self.mol_object_buffer[smiles] = augmented_mol - edge_index = self._augment_graph(mol) - - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + return augmented_mol, edge_index def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: - """Load smiles into rdkit, store object in buffer""" - if smiles in self.mol_object_buffer: - return self.mol_object_buffer[smiles] - mol = Chem.MolFromSmiles(smiles) if mol is None: rank_zero_warn(f"RDKit failed to at parsing {smiles} 
(returned None)") @@ -195,9 +203,8 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: try: Chem.SanitizeMol(mol) except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") + rank_zero_warn(f"Rdkit failed at sanitizing {smiles}, Error {e}") self.failed_counter += 1 - self.mol_object_buffer[smiles] = mol return mol def _augment_graph(self, mol: Mol): @@ -210,23 +217,16 @@ def _augment_graph(self, mol: Mol): within_atoms_edge_index = torch.cat([edge_index, edge_index[[1, 0], :]], dim=1) num_of_nodes = mol.GetNumAtoms() + num_of_edges = mol.GetNumBonds() set_atom_map_num(mol) detect_functional_group(mol) - node_features = [] - sorted_atoms = sorted( - list(mol.GetAtoms()), key=lambda atom: atom.GetAtomMapNum() - ) + for atom in mol.GetAtoms(): + atom.SetProp(NODE_LEVEL, ATOM_NODE_LEVEL) - for idx, atom in enumerate(sorted_atoms): - node_features.append( - [ - self.NODE_LEVEL["atom_node"], - self._get_fg_index(atom), - self._get_ring_size(atom), - ] - ) + for edge in mol.GetBonds(): + edge.SetProp(EDGE_LEVEL, WITHIN_ATOMS_EDGE) structure, bonds = get_structure(mol) @@ -235,6 +235,7 @@ def _augment_graph(self, mol: Mol): # Preprocess the molecular structure to match feature dictionary keys fg_to_atoms_edge_index = [[], []] + fg_nodes, fg_atom_edges = {}, {} new_structure = {} for idx, fg in enumerate(structure): # new_sm = preprocess_smiles(sm) # Preprocess SMILES to match the feature dictionary @@ -244,18 +245,19 @@ def _augment_graph(self, mol: Mol): for atom in structure[fg]["atom"]: fg_to_atoms_edge_index[0].extend([num_of_nodes, atom]) fg_to_atoms_edge_index[1].extend([atom, num_of_nodes]) + fg_atom_edges[f"{num_of_nodes}_{atom}"] = {EDGE_LEVEL: ATOM_FG_EDGE} + num_of_edges += 1 any_atom = next(iter(structure[fg]["atom"][0])) # any atom related to fg - node_features.append( - [ - self.NODE_LEVEL["fg_node"], - self._get_fg_index(any_atom), - self._get_ring_size(any_atom), - ] - ) + fg_nodes[num_of_nodes] = { + NODE_LEVEL: 
FG_NODE_LEVEL, + "FG": any_atom.GetProp("FG"), + "RING": any_atom.GetProp("RING"), + } num_of_nodes += 1 + fg_edges = {} within_fg_edge_index = [[], []] for bond in bonds: start_idx, end_idx = bond[:2] @@ -266,14 +268,24 @@ def _augment_graph(self, mol: Mol): target_fg = key within_fg_edge_index[0].extend([source_fg, target_fg]) within_fg_edge_index[1].extend([target_fg, source_fg]) + fg_edges[f"{source_fg}_{target_fg}"] = {EDGE_LEVEL: WITHIN_FG_EDGE} + num_of_edges += 1 - node_features.append( - [self.NODE_LEVEL["global_node"], self._get_token_index("graph_fg"), 0] - ) + graph_node = { + NODE_LEVEL: GRAPH_NODE_LEVEL, + "FG": "graph_fg", + "RING": "0", + } + + fg_graphNode_edges = {} global_node_edge_index = [[], []] for fg in new_structure.keys(): global_node_edge_index[0].extend([num_of_nodes, fg]) global_node_edge_index[1].extend([fg, num_of_nodes]) + fg_graphNode_edges[f"{num_of_nodes}_{fg}"] = { + NODE_LEVEL: FG_GRAPHNODE_LEVEL + } + num_of_edges += 1 all_edges = torch.cat( [ @@ -285,7 +297,21 @@ def _augment_graph(self, mol: Mol): dim=1, ) - return all_edges + augmented_graph_nodes = { + "atom_nodes": mol, + "fg_nodes": fg_nodes, + "graph_node": graph_node, + "num_nodes": num_of_nodes, + } + augmented_graph_edges = { + WITHIN_ATOMS_EDGE: mol, + WITHIN_FG_EDGE: fg_edges, + ATOM_FG_EDGE: fg_atom_edges, + FG_GRAPHNODE_LEVEL: fg_graphNode_edges, + "num_edges": num_of_edges, + } + + return all_edges, augmented_graph_nodes, augmented_graph_edges def _get_fg_index(self, atom): fg_group = atom.GetProp("FG") @@ -314,4 +340,9 @@ def read_property( mol = self._smiles_to_mol(smiles) if mol is None: return None + + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + + augmented_mol, _ = self._get_augmented_molecule(smiles) return property.get_property_value(mol) From 5c619861cc615051944dee9cdeb698b525e7b4db Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 21 Apr 2025 16:53:57 +0200 Subject: [PATCH 009/224] move 
fg constants to seperate file --- .../fg_detection/fg_constants.py | 21 ++ .../preprocessing/fg_detection/rule_based.py | 203 +----------------- 2 files changed, 22 insertions(+), 202 deletions(-) create mode 100644 chebai_graph/preprocessing/fg_detection/fg_constants.py diff --git a/chebai_graph/preprocessing/fg_detection/fg_constants.py b/chebai_graph/preprocessing/fg_detection/fg_constants.py new file mode 100644 index 0000000..f96a125 --- /dev/null +++ b/chebai_graph/preprocessing/fg_detection/fg_constants.py @@ -0,0 +1,21 @@ +ATOM_NODE_LEVEL = "atom_node_lvl" +FG_NODE_LEVEL = "fg_node_lvl" +GRAPH_NODE_LEVEL = "graph_node_level" +NODE_LEVEL = "node_level" +NODE_LEVELS = {ATOM_NODE_LEVEL, FG_NODE_LEVEL, GRAPH_NODE_LEVEL} + +EDGE_LEVEL = "edge_level" +WITHIN_ATOMS_EDGE = "within_atoms_lvl" +WITHIN_FG_EDGE = "within_fg_lvl" +ATOM_FG_EDGE = "atom_fg_lvl" +FG_GRAPHNODE_LEVEL = "fg_graphNode_lvl" +EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_LEVEL} + +# fmt: off +ELEMENTS = {"Ac", "Ag", "Al", "Am", "As", "At", "Au", "B", "Ba", "Be", "Bi", "Bk", "Br", "Ca", "Cd", "Ce", "Cf", "Cl", + "Cm", "Co", "Cr", "Cs", "Cu", "Dy", "Er", "Es", "Eu", "F", "Fe", "Fm", "Fr", "Ga", "Gd", "Ge", "He", "Hf", + "Hg", "Ho", "I", "In", "Ir", "K", "Kr", "La", "Li", "Lr", "Lu", "Md", "Mg", "Mn", "Mo", "N", "Na", "Nb", + "Nd", "Ne", "Ni", "Np", "O", "Os", "P", "Pa", "Pb", "Pd", "Pm", "Po", "Pr", "Pt", "Pu", "Ra", "Rb", "Re", + "Rh", "Rn", "Ru", "S", "Sb", "Sc", "Se", "Si", "Sm", "Sn", "Sr", "Ta", "Tb", "Tc", "Te", "Th", "Ti", "Tl", + "Tm", "U", "V", "W", "Xe", "Y", "Yb", "Zn", "Zr"} +# fmt: on diff --git a/chebai_graph/preprocessing/fg_detection/rule_based.py b/chebai_graph/preprocessing/fg_detection/rule_based.py index 3a1083a..ab4fe30 100644 --- a/chebai_graph/preprocessing/fg_detection/rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/rule_based.py @@ -8,105 +8,7 @@ from rdkit.Chem import AllChem from rdkit.Chem import MolToSmiles as m2s 
-electronegativity = { - "H": 2.2, - "LI": 0.98, - "BE": 1.57, - "B": 2.04, - "C": 2.55, - "N": 3.04, - "O": 3.44, - "F": 3.98, - "NA": 0.93, - "MG": 1.31, - "AL": 1.61, - "SI": 1.9, - "P": 2.19, - "S": 2.58, - "CL": 3.16, - "K": 0.82, - "CA": 1.0, - "SC": 1.36, - "TI": 1.54, - "V": 1.63, - "CR": 1.66, - "MN": 1.55, - "FE": 1.83, - "CO": 1.88, - "NI": 1.91, - "CU": 1.9, - "ZN": 1.65, - "GA": 1.81, - "GE": 2.01, - "AS": 2.18, - "SE": 2.55, - "BR": 2.96, - "RB": 0.82, - "SR": 0.95, - "Y": 1.22, - "ZR": 1.33, - "NB": 1.6, - "MO": 2.16, - "TC": 1.9, - "RU": 2.2, - "RH": 2.28, - "PD": 2.2, - "AG": 1.93, - "CD": 1.69, - "IN": 1.78, - "SN": 1.96, - "SB": 2.05, - "TE": 2.1, - "I": 2.66, - "CS": 0.79, - "BA": 0.89, - "LA": 1.1, - "CE": 1.12, - "PR": 1.13, - "ND": 1.14, - "PM": 1.13, - "SM": 1.17, - "EU": 1.2, - "GD": 1.2, - "TB": 1.1, - "DY": 1.22, - "HO": 1.23, - "ER": 1.24, - "TM": 1.25, - "YB": 1.1, - "LU": 1.27, - "HF": 1.3, - "TA": 1.5, - "W": 2.36, - "RE": 1.9, - "OS": 2.2, - "IR": 2.2, - "PT": 2.28, - "AU": 2.54, - "HG": 2.0, - "TL": 1.62, - "PB": 2.33, - "BI": 2.02, - "PO": 2.0, - "AT": 2.2, - "FR": 0.7, - "RA": 0.9, - "AC": 1.1, - "TH": 1.3, - "PA": 1.5, - "U": 1.38, - "NP": 1.36, - "PU": 1.28, - "AM": 1.3, - "CM": 1.3, - "BK": 1.3, - "CF": 1.3, - "ES": 1.3, - "FM": 1.3, - "MD": 1.3, - "NO": 1.3, - "LR": 1.3, -} +from chebai_graph.preprocessing.fg_detection.fg_constants import ELEMENTS def ring_size_processing(ring_size): @@ -133,109 +35,6 @@ def find_connected_rings(ring, remaining_rings): def detect_functional_group(mol): # type: ignore AllChem.GetSymmSSSR(mol) # type: ignore - ELEMENTS = set( - [ - "Ac", - "Ag", - "Al", - "Am", - "As", - "At", - "Au", - "B", - "Ba", - "Be", - "Bi", - "Bk", - "Br", - "Ca", - "Cd", - "Ce", - "Cf", - "Cl", - "Cm", - "Co", - "Cr", - "Cs", - "Cu", - "Dy", - "Er", - "Es", - "Eu", - "F", - "Fe", - "Fm", - "Fr", - "Ga", - "Gd", - "Ge", - "He", - "Hf", - "Hg", - "Ho", - "I", - "In", - "Ir", - "K", - "Kr", - "La", - "Li", - "Lr", - "Lu", 
- "Md", - "Mg", - "Mn", - "Mo", - "N", - "Na", - "Nb", - "Nd", - "Ne", - "Ni", - "Np", - "O", - "Os", - "P", - "Pa", - "Pb", - "Pd", - "Pm", - "Po", - "Pr", - "Pt", - "Pu", - "Ra", - "Rb", - "Re", - "Rh", - "Rn", - "Ru", - "S", - "Sb", - "Sc", - "Se", - "Si", - "Sm", - "Sn", - "Sr", - "Ta", - "Tb", - "Tc", - "Te", - "Th", - "Ti", - "Tl", - "Tm", - "U", - "V", - "W", - "Xe", - "Y", - "Yb", - "Zn", - "Zr", - ] - ) if mol is not None: for atom in mol.GetAtoms(): From 11ff800b1394b20b7ded18b8011cf4fc1058cf7b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 21 Apr 2025 16:56:11 +0200 Subject: [PATCH 010/224] Elements: re-align --- .../preprocessing/fg_detection/fg_constants.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_constants.py b/chebai_graph/preprocessing/fg_detection/fg_constants.py index f96a125..24f315f 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_constants.py +++ b/chebai_graph/preprocessing/fg_detection/fg_constants.py @@ -12,10 +12,13 @@ EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_LEVEL} # fmt: off -ELEMENTS = {"Ac", "Ag", "Al", "Am", "As", "At", "Au", "B", "Ba", "Be", "Bi", "Bk", "Br", "Ca", "Cd", "Ce", "Cf", "Cl", - "Cm", "Co", "Cr", "Cs", "Cu", "Dy", "Er", "Es", "Eu", "F", "Fe", "Fm", "Fr", "Ga", "Gd", "Ge", "He", "Hf", - "Hg", "Ho", "I", "In", "Ir", "K", "Kr", "La", "Li", "Lr", "Lu", "Md", "Mg", "Mn", "Mo", "N", "Na", "Nb", - "Nd", "Ne", "Ni", "Np", "O", "Os", "P", "Pa", "Pb", "Pd", "Pm", "Po", "Pr", "Pt", "Pu", "Ra", "Rb", "Re", - "Rh", "Rn", "Ru", "S", "Sb", "Sc", "Se", "Si", "Sm", "Sn", "Sr", "Ta", "Tb", "Tc", "Te", "Th", "Ti", "Tl", - "Tm", "U", "V", "W", "Xe", "Y", "Yb", "Zn", "Zr"} +ELEMENTS = { + "Ac", "Ag", "Al", "Am", "As", "At", "Au", "B", "Ba", "Be", "Bi", "Bk", "Br", "Ca", + "Cd", "Ce", "Cf", "Cl", "Cm", "Co", "Cr", "Cs", "Cu", "Dy", "Er", "Es", "Eu", "F", + "Fe", "Fm", "Fr", "Ga", "Gd", "Ge", "He", "Hf", "Hg", 
"Ho", "I", "In", "Ir", "K", + "Kr", "La", "Li", "Lr", "Lu", "Md", "Mg", "Mn", "Mo", "N", "Na", "Nb", "Nd", "Ne", + "Ni", "Np", "O", "Os", "P", "Pa", "Pb", "Pd", "Pm", "Po", "Pr", "Pt", "Pu", "Ra", + "Rb", "Re", "Rh", "Rn", "Ru", "S", "Sb", "Sc", "Se", "Si", "Sm", "Sn", "Sr", "Ta", + "Tb", "Tc", "Te", "Th", "Ti", "Tl", "Tm", "U", "V", "W", "Xe", "Y", "Yb", "Zn", "Zr" +} # fmt: on From 4162caae2f36602e2bccbd14397165c3bb6fd681 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 22 Apr 2025 14:53:33 +0200 Subject: [PATCH 011/224] properties: add base class for augmented bond property --- chebai_graph/preprocessing/properties.py | 46 ++++++++++++++---------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 15f2e61..e77198e 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -7,9 +7,12 @@ from rdkit.Chem import Mol from chebai_graph.preprocessing.fg_detection.fg_constants import ( + ATOM_FG_EDGE, EDGE_LEVELS, + FG_GRAPHNODE_LEVEL, NODE_LEVEL, WITHIN_ATOMS_EDGE, + WITHIN_FG_EDGE, ) from chebai_graph.preprocessing.property_encoder import ( AsIsEncoder, @@ -136,46 +139,53 @@ def get_property_value(self, augmented_mol: Dict): atom_molecule: Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE] if not isinstance(atom_molecule, Mol): raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' + f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' ) - prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] + prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] - fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] - graph_node = atom_molecule[self.MAIN_KEY]["graph_node"] - if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict): + fg_atom_edges = 
augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE] + fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE] + fg_graphNode_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_LEVEL] + + if ( + not isinstance(fg_atom_edges, dict) + or not isinstance(fg_edges, dict) + or not isinstance(fg_graphNode_edges, dict) + ): raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict ' - f"containing its properties" + f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_LEVEL}"]) ' + f"must be an instance of dict containing its properties" ) # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes]) - prop_list.extend([self.get_atom_value(atom) for atom in graph_node]) + prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges]) + prop_list.extend([self.get_bond_value(bond) for bond in fg_edges]) + prop_list.extend([self.get_bond_value(bond) for bond in fg_graphNode_edges]) return prop_list @abc.abstractmethod - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): pass - def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): - value = self._get_atom_prop_value(atom, prop) + def _check_modify_atom_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str): + value = self._get_bond_prop_value(bond, prop) if not value: # Every atom/node should have given value raise ValueError(f"'{prop}' is set but empty.") return value @staticmethod - def _get_atom_prop_value(atom: Chem.rdchem.Atom | Dict, prop: str): - if isinstance(atom, Chem.rdchem.Atom): - return atom.GetProp(prop) - elif isinstance(atom, dict): - return atom[prop] + def 
_get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): + if isinstance(bond, Chem.rdchem.Bond): + return bond.GetProp(prop) + elif isinstance(bond, dict): + return bond[prop] else: - raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.") + raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") class MoleculeProperty(MolecularProperty): From 601d4c466640b1984ed6dec0482d8061ea46ea9c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 22 Apr 2025 21:36:30 +0200 Subject: [PATCH 012/224] move augmented properties to utils --- .../preprocessing/fg_detection/fg_constants.py | 13 ------------- chebai_graph/preprocessing/utils/__init__.py | 0 .../preprocessing/utils/properties_constants.py | 12 ++++++++++++ 3 files changed, 12 insertions(+), 13 deletions(-) create mode 100644 chebai_graph/preprocessing/utils/__init__.py create mode 100644 chebai_graph/preprocessing/utils/properties_constants.py diff --git a/chebai_graph/preprocessing/fg_detection/fg_constants.py b/chebai_graph/preprocessing/fg_detection/fg_constants.py index 24f315f..9b71e9f 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_constants.py +++ b/chebai_graph/preprocessing/fg_detection/fg_constants.py @@ -1,16 +1,3 @@ -ATOM_NODE_LEVEL = "atom_node_lvl" -FG_NODE_LEVEL = "fg_node_lvl" -GRAPH_NODE_LEVEL = "graph_node_level" -NODE_LEVEL = "node_level" -NODE_LEVELS = {ATOM_NODE_LEVEL, FG_NODE_LEVEL, GRAPH_NODE_LEVEL} - -EDGE_LEVEL = "edge_level" -WITHIN_ATOMS_EDGE = "within_atoms_lvl" -WITHIN_FG_EDGE = "within_fg_lvl" -ATOM_FG_EDGE = "atom_fg_lvl" -FG_GRAPHNODE_LEVEL = "fg_graphNode_lvl" -EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_LEVEL} - # fmt: off ELEMENTS = { "Ac", "Ag", "Al", "Am", "As", "At", "Au", "B", "Ba", "Be", "Bi", "Bk", "Br", "Ca", diff --git a/chebai_graph/preprocessing/utils/__init__.py b/chebai_graph/preprocessing/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/chebai_graph/preprocessing/utils/properties_constants.py b/chebai_graph/preprocessing/utils/properties_constants.py new file mode 100644 index 0000000..fef6cfd --- /dev/null +++ b/chebai_graph/preprocessing/utils/properties_constants.py @@ -0,0 +1,12 @@ +ATOM_NODE_LEVEL = "atom_node_lvl" +FG_NODE_LEVEL = "fg_node_lvl" +GRAPH_NODE_LEVEL = "graph_node_level" +NODE_LEVEL = "node_level" +NODE_LEVELS = {ATOM_NODE_LEVEL, FG_NODE_LEVEL, GRAPH_NODE_LEVEL} + +EDGE_LEVEL = "edge_level" +WITHIN_ATOMS_EDGE = "within_atoms_lvl" +WITHIN_FG_EDGE = "within_fg_lvl" +ATOM_FG_EDGE = "atom_fg_lvl" +FG_GRAPHNODE_LEVEL = "fg_graphNode_lvl" +EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_LEVEL} From 51e4676ef2c19dfe3c99818371710dd16113ad9e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 22 Apr 2025 21:37:01 +0200 Subject: [PATCH 013/224] add bond level prop for augmented bond --- chebai_graph/preprocessing/properties.py | 19 ++++++++++--------- chebai_graph/preprocessing/reader.py | 14 ++------------ 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index e77198e..c969d66 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -6,14 +6,6 @@ from descriptastorus.descriptors import rdNormalizedDescriptors from rdkit.Chem import Mol -from chebai_graph.preprocessing.fg_detection.fg_constants import ( - ATOM_FG_EDGE, - EDGE_LEVELS, - FG_GRAPHNODE_LEVEL, - NODE_LEVEL, - WITHIN_ATOMS_EDGE, - WITHIN_FG_EDGE, -) from chebai_graph.preprocessing.property_encoder import ( AsIsEncoder, BoolEncoder, @@ -21,6 +13,7 @@ OneHotEncoder, PropertyEncoder, ) +from chebai_graph.preprocessing.utils.properties_constants import * class MolecularProperty(abc.ABC): @@ -171,7 +164,7 @@ def get_property_value(self, augmented_mol: Dict): def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): pass - def 
_check_modify_atom_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str): + def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str): value = self._get_bond_prop_value(bond, prop) if not value: # Every atom/node should have given value @@ -305,6 +298,14 @@ def get_bond_value(self, bond: Chem.rdchem.Bond): return bond.IsInRing() +class BondLevel(AugmentedBondProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): + return self._check_modify_bond_prop_value(bond, EDGE_LEVEL) + + class MoleculeNumRings(MolecularProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 5e915f0..cd103a9 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -13,22 +13,12 @@ import chebai_graph.preprocessing.properties as properties from chebai_graph.preprocessing.collate import GraphCollator -from chebai_graph.preprocessing.fg_detection.fg_constants import ( - ATOM_FG_EDGE, - ATOM_NODE_LEVEL, - EDGE_LEVEL, - FG_GRAPHNODE_LEVEL, - FG_NODE_LEVEL, - GRAPH_NODE_LEVEL, - NODE_LEVEL, - WITHIN_ATOMS_EDGE, - WITHIN_FG_EDGE, -) from chebai_graph.preprocessing.fg_detection.rule_based import ( detect_functional_group, get_structure, set_atom_map_num, ) +from chebai_graph.preprocessing.utils.properties_constants import * class GraphPropertyReader(dr.ChemDataReader): @@ -305,8 +295,8 @@ def _augment_graph(self, mol: Mol): } augmented_graph_edges = { WITHIN_ATOMS_EDGE: mol, - WITHIN_FG_EDGE: fg_edges, ATOM_FG_EDGE: fg_atom_edges, + WITHIN_FG_EDGE: fg_edges, FG_GRAPHNODE_LEVEL: fg_graphNode_edges, "num_edges": num_of_edges, } From 8a2f8946e634297fd1369b7788cf85af027dd37a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 28 Apr 2025 23:37:14 
+0200 Subject: [PATCH 014/224] add fg reader --- chebai_graph/preprocessing/reader.py | 33 ++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index cd103a9..e3823e2 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -336,3 +336,36 @@ def read_property( augmented_mol, _ = self._get_augmented_molecule(smiles) return property.get_property_value(mol) + + +class RuleBasedFGReader(dr.ChemDataReader): + + @classmethod + def name(cls) -> str: + return "rule_based_fg" + + def _read_data(self, augmented_mol: dict) -> List[int] | None: + feature_vector = [] + augmented_mol_nodes = augmented_mol["nodes"] + + if "atom_nodes" in augmented_mol_nodes: + if not isinstance(augmented_mol_nodes["atom_nodes"], Chem.Mol): + raise TypeError(f"augmented_mol_nodes['atom_nodes'] should be Chem.Mol") + feature_vector.extend( + self._get_token_index(node.GetProp("FG")) + for node in augmented_mol_nodes["atom_nodes"] + ) + + if "fg_nodes" in augmented_mol_nodes: + feature_vector.extend( + self._get_token_index(node["FG"]) + for node in augmented_mol_nodes["fg_nodes"] + ) + + if "graph_node" in augmented_mol_nodes: + feature_vector.extend( + self._get_token_index(node["FG"]) + for node in augmented_mol_nodes["graph_node"] + ) + + return feature_vector if feature_vector else None From c102fee052f68197b74ab0afee2e4910e6d69d98 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 16:18:28 +0200 Subject: [PATCH 015/224] AtomFG prop reads from fgreader --- chebai_graph/preprocessing/properties.py | 17 ++++++++++++--- chebai_graph/preprocessing/reader.py | 27 ++---------------------- 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index c969d66..d163774 100644 --- a/chebai_graph/preprocessing/properties.py +++ 
b/chebai_graph/preprocessing/properties.py @@ -13,6 +13,7 @@ OneHotEncoder, PropertyEncoder, ) +from chebai_graph.preprocessing.reader import RuleBasedFGReader from chebai_graph.preprocessing.utils.properties_constants import * @@ -98,14 +99,15 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str raise ValueError(f"'{prop}' is set but empty.") return value - @staticmethod - def _get_atom_prop_value(atom: Chem.rdchem.Atom | Dict, prop: str): + def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): if isinstance(atom, Chem.rdchem.Atom): return atom.GetProp(prop) elif isinstance(atom, dict): return atom[prop] else: - raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.") + raise TypeError( + f"Atom/Node in key `{self.MAIN_KEY}` should be of type `Chem.rdchem.Atom` or `dict`." + ) class BondProperty(MolecularProperty): @@ -252,10 +254,19 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): class AtomFunctionalGroup(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) + self.fg_reader = RuleBasedFGReader() def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): return self._check_modify_atom_prop_value(atom, "FG") + def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + if isinstance(atom, Chem.rdchem.Atom): + return self.fg_reader._read_data(atom.GetProp(prop)) # noqa + elif isinstance(atom, dict): + return self.fg_reader._read_data(atom[prop]) # noqa + else: + raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.") + class AtomRingSize(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index e3823e2..da18c24 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -344,28 +344,5 @@ class 
RuleBasedFGReader(dr.ChemDataReader): def name(cls) -> str: return "rule_based_fg" - def _read_data(self, augmented_mol: dict) -> List[int] | None: - feature_vector = [] - augmented_mol_nodes = augmented_mol["nodes"] - - if "atom_nodes" in augmented_mol_nodes: - if not isinstance(augmented_mol_nodes["atom_nodes"], Chem.Mol): - raise TypeError(f"augmented_mol_nodes['atom_nodes'] should be Chem.Mol") - feature_vector.extend( - self._get_token_index(node.GetProp("FG")) - for node in augmented_mol_nodes["atom_nodes"] - ) - - if "fg_nodes" in augmented_mol_nodes: - feature_vector.extend( - self._get_token_index(node["FG"]) - for node in augmented_mol_nodes["fg_nodes"] - ) - - if "graph_node" in augmented_mol_nodes: - feature_vector.extend( - self._get_token_index(node["FG"]) - for node in augmented_mol_nodes["graph_node"] - ) - - return feature_vector if feature_vector else None + def _read_data(self, fg: str) -> int | None: + return self._get_token_index(fg) From f9c959bf0a338fcb6e33789693df75676b26e97a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 16:31:28 +0200 Subject: [PATCH 016/224] move properties to properties dir --- .../preprocessing/properties/__init__.py | 2 + .../augmented_properties.py} | 262 ++++-------------- .../preprocessing/properties/properties.py | 159 +++++++++++ 3 files changed, 215 insertions(+), 208 deletions(-) create mode 100644 chebai_graph/preprocessing/properties/__init__.py rename chebai_graph/preprocessing/{properties.py => properties/augmented_properties.py} (59%) create mode 100644 chebai_graph/preprocessing/properties/properties.py diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py new file mode 100644 index 0000000..525d248 --- /dev/null +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -0,0 +1,2 @@ +from .augmented_properties import * +from .properties import * diff --git a/chebai_graph/preprocessing/properties.py 
b/chebai_graph/preprocessing/properties/augmented_properties.py similarity index 59% rename from chebai_graph/preprocessing/properties.py rename to chebai_graph/preprocessing/properties/augmented_properties.py index d163774..90f4fd8 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -1,123 +1,13 @@ import abc from typing import Dict, Optional -import numpy as np -import rdkit.Chem as Chem -from descriptastorus.descriptors import rdNormalizedDescriptors -from rdkit.Chem import Mol - -from chebai_graph.preprocessing.property_encoder import ( - AsIsEncoder, - BoolEncoder, - IndexEncoder, - OneHotEncoder, - PropertyEncoder, -) +from rdkit import Chem + +from chebai_graph.preprocessing import MolecularProperty, OneHotEncoder, PropertyEncoder from chebai_graph.preprocessing.reader import RuleBasedFGReader from chebai_graph.preprocessing.utils.properties_constants import * -class MolecularProperty(abc.ABC): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - if encoder is None: - encoder = IndexEncoder(self) - self.encoder = encoder - - @property - def name(self): - """Unique identifier for this property.""" - return self.__class__.__name__ - - def on_finish(self): - """Called after dataset processing is done.""" - self.encoder.on_finish() - - def __str__(self): - return self.name - - def get_property_value(self, mol: Chem.rdchem.Mol): - raise NotImplementedError - - -class AtomProperty(MolecularProperty, abc): - """Property of an atom.""" - - def get_property_value(self, mol: Chem.rdchem.Mol): - return [self.get_atom_value(atom) for atom in mol.GetAtoms()] - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return NotImplementedError - - -class AugmentedAtomProperty(MolecularProperty, abc): - MAIN_KEY = "nodes" - - def get_property_value(self, augmented_mol: Dict): - if self.MAIN_KEY not in augmented_mol: - raise KeyError( - f"Key `{self.MAIN_KEY}` should be present in augmented 
molecule dict" - ) - - missing_keys = {"atom_nodes", "fg_nodes", "graph_node"} - augmented_mol[ - self.MAIN_KEY - ].keys() - if missing_keys: - raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") - - atom_molecule: Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] - if not isinstance(atom_molecule, Mol): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' - ) - - prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] - - fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] - graph_node = atom_molecule[self.MAIN_KEY]["graph_node"] - if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict ' - f"containing its properties" - ) - - # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order - # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights - # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes]) - prop_list.extend([self.get_atom_value(atom) for atom in graph_node]) - - return prop_list - - @abc.abstractmethod - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): - pass - - def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): - value = self._get_atom_prop_value(atom, prop) - if not value: - # Every atom/node should have given value - raise ValueError(f"'{prop}' is set but empty.") - return value - - def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): - if isinstance(atom, Chem.rdchem.Atom): - return atom.GetProp(prop) - elif isinstance(atom, dict): - return atom[prop] - else: - raise TypeError( - f"Atom/Node in key `{self.MAIN_KEY}` should be of type `Chem.rdchem.Atom` or `dict`." 
- ) - - -class BondProperty(MolecularProperty): - def get_property_value(self, mol: Chem.rdchem.Mol): - return [self.get_bond_value(bond) for bond in mol.GetBonds()] - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return NotImplementedError - - class AugmentedBondProperty(MolecularProperty, abc): MAIN_KEY = "edges" @@ -131,8 +21,8 @@ def get_property_value(self, augmented_mol: Dict): if missing_keys: raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") - atom_molecule: Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE] - if not isinstance(atom_molecule, Mol): + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE] + if not isinstance(atom_molecule, Chem.Mol): raise TypeError( f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' ) @@ -183,64 +73,65 @@ def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") -class MoleculeProperty(MolecularProperty): - """Global property of a molecule.""" - - -class AtomType(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetAtomicNum() - - -class NumAtomBonds(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetDegree() - - -class AtomCharge(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetFormalCharge() - - -class AtomChirality(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) +class AugmentedAtomProperty(MolecularProperty, 
abc): + MAIN_KEY = "nodes" - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetChiralTag() + def get_property_value(self, augmented_mol: Dict): + if self.MAIN_KEY not in augmented_mol: + raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + missing_keys = {"atom_nodes", "fg_nodes", "graph_node"} - augmented_mol[ + self.MAIN_KEY + ].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") -class AtomHybridization(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] + if not isinstance(atom_molecule, Chem.Mol): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' + ) - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetHybridization() + prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] + fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] + graph_node = atom_molecule[self.MAIN_KEY]["graph_node"] + if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict ' + f"containing its properties" + ) -class AtomNumHs(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) + # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order + # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights + # https://mail.python.org/pipermail/python-dev/2017-December/151283.html + prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes]) + prop_list.extend([self.get_atom_value(atom) for atom in graph_node]) - def get_atom_value(self, atom: Chem.rdchem.Atom): - return 
atom.GetTotalNumHs() + return prop_list + @abc.abstractmethod + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + pass -class AtomAromaticity(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or BoolEncoder(self)) + def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + value = self._get_atom_prop_value(atom, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value - def get_atom_value(self, atom: Chem.rdchem.Atom): - return atom.GetIsAromatic() + def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + if isinstance(atom, Chem.rdchem.Atom): + return atom.GetProp(prop) + elif isinstance(atom, dict): + return atom[prop] + else: + raise TypeError( + f"Atom/Node in key `{self.MAIN_KEY}` should be of type `Chem.rdchem.Atom` or `dict`." + ) class AtomNodeLevel(AugmentedAtomProperty): @@ -285,54 +176,9 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str return 0 -class BondAromaticity(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or BoolEncoder(self)) - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return bond.GetIsAromatic() - - -class BondType(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return bond.GetBondType() - - -class BondInRing(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or BoolEncoder(self)) - - def get_bond_value(self, bond: Chem.rdchem.Bond): - return bond.IsInRing() - - class BondLevel(AugmentedBondProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) def get_bond_value(self, bond: Chem.rdchem.Bond | 
Dict): return self._check_modify_bond_prop_value(bond, EDGE_LEVEL) - - -class MoleculeNumRings(MolecularProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or OneHotEncoder(self)) - - def get_property_value(self, mol: Chem.rdchem.Mol): - return [mol.GetRingInfo().NumRings()] - - -class RDKit2DNormalized(MolecularProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - super().__init__(encoder or AsIsEncoder(self)) - - def get_property_value(self, mol: Chem.rdchem.Mol): - generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() - features_normalized = generator_normalized.processMol( - mol, Chem.MolToSmiles(mol) - ) - np.nan_to_num(features_normalized) - return [features_normalized[1:]] diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py new file mode 100644 index 0000000..5ee942b --- /dev/null +++ b/chebai_graph/preprocessing/properties/properties.py @@ -0,0 +1,159 @@ +import abc +from typing import Optional + +import numpy as np +import rdkit.Chem as Chem +from descriptastorus.descriptors import rdNormalizedDescriptors + +from chebai_graph.preprocessing.property_encoder import ( + AsIsEncoder, + BoolEncoder, + IndexEncoder, + OneHotEncoder, + PropertyEncoder, +) + + +class MolecularProperty(abc.ABC): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + if encoder is None: + encoder = IndexEncoder(self) + self.encoder = encoder + + @property + def name(self): + """Unique identifier for this property.""" + return self.__class__.__name__ + + def on_finish(self): + """Called after dataset processing is done.""" + self.encoder.on_finish() + + def __str__(self): + return self.name + + def get_property_value(self, mol: Chem.rdchem.Mol): + raise NotImplementedError + + +class AtomProperty(MolecularProperty, abc): + """Property of an atom.""" + + def get_property_value(self, mol: Chem.rdchem.Mol): + return 
[self.get_atom_value(atom) for atom in mol.GetAtoms()] + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return NotImplementedError + + +class BondProperty(MolecularProperty): + def get_property_value(self, mol: Chem.rdchem.Mol): + return [self.get_bond_value(bond) for bond in mol.GetBonds()] + + def get_bond_value(self, bond: Chem.rdchem.Bond): + return NotImplementedError + + +class MoleculeProperty(MolecularProperty): + """Global property of a molecule.""" + + +class AtomType(AtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return atom.GetAtomicNum() + + +class NumAtomBonds(AtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return atom.GetDegree() + + +class AtomCharge(AtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return atom.GetFormalCharge() + + +class AtomChirality(AtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return atom.GetChiralTag() + + +class AtomHybridization(AtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return atom.GetHybridization() + + +class AtomNumHs(AtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return atom.GetTotalNumHs() + + +class AtomAromaticity(AtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = 
None): + super().__init__(encoder or BoolEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom): + return atom.GetIsAromatic() + + +class BondAromaticity(BondProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or BoolEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond): + return bond.GetIsAromatic() + + +class BondType(BondProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond): + return bond.GetBondType() + + +class BondInRing(BondProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or BoolEncoder(self)) + + def get_bond_value(self, bond: Chem.rdchem.Bond): + return bond.IsInRing() + + +class MoleculeNumRings(MolecularProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or OneHotEncoder(self)) + + def get_property_value(self, mol: Chem.rdchem.Mol): + return [mol.GetRingInfo().NumRings()] + + +class RDKit2DNormalized(MolecularProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or AsIsEncoder(self)) + + def get_property_value(self, mol: Chem.rdchem.Mol): + generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() + features_normalized = generator_normalized.processMol( + mol, Chem.MolToSmiles(mol) + ) + np.nan_to_num(features_normalized) + return [features_normalized[1:]] From c3448e31eb1d8343aeae7611bd6d7a621f2f3583 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 16:33:45 +0200 Subject: [PATCH 017/224] data config for augmented graph --- configs/data/chebi50_augmented_gnn.yml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 configs/data/chebi50_augmented_gnn.yml diff --git a/configs/data/chebi50_augmented_gnn.yml b/configs/data/chebi50_augmented_gnn.yml new file mode 100644 
index 0000000..c748ac2 --- /dev/null +++ b/configs/data/chebi50_augmented_gnn.yml @@ -0,0 +1,7 @@ +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphProperties +init_args: + properties: + - chebai_graph.preprocessing.properties.AtomRingSize + - chebai_graph.preprocessing.properties.AtomNodeLevel + - chebai_graph.preprocessing.properties.AtomFunctionalGroup + - chebai_graph.preprocessing.properties.BondLevel From 38de6f38f28e7d434f21de1328cd0ee3e043b276 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 16:35:52 +0200 Subject: [PATCH 018/224] right molecular prop import --- chebai_graph/preprocessing/properties/augmented_properties.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 90f4fd8..e1e838c 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -3,7 +3,8 @@ from rdkit import Chem -from chebai_graph.preprocessing import MolecularProperty, OneHotEncoder, PropertyEncoder +from chebai_graph.preprocessing import OneHotEncoder, PropertyEncoder +from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.reader import RuleBasedFGReader from chebai_graph.preprocessing.utils.properties_constants import * From 6ff16091794418fdf249cab654993d8bc3ac6ca0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 17:07:37 +0200 Subject: [PATCH 019/224] move prop constants to props dir --- chebai_graph/preprocessing/properties/augmented_properties.py | 2 +- .../{utils/properties_constants.py => properties/constants.py} | 0 chebai_graph/preprocessing/reader.py | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) rename chebai_graph/preprocessing/{utils/properties_constants.py => properties/constants.py} (100%) diff --git 
a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index e1e838c..4e57732 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -5,8 +5,8 @@ from chebai_graph.preprocessing import OneHotEncoder, PropertyEncoder from chebai_graph.preprocessing.properties import MolecularProperty +from chebai_graph.preprocessing.properties.constants import * from chebai_graph.preprocessing.reader import RuleBasedFGReader -from chebai_graph.preprocessing.utils.properties_constants import * class AugmentedBondProperty(MolecularProperty, abc): diff --git a/chebai_graph/preprocessing/utils/properties_constants.py b/chebai_graph/preprocessing/properties/constants.py similarity index 100% rename from chebai_graph/preprocessing/utils/properties_constants.py rename to chebai_graph/preprocessing/properties/constants.py diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index da18c24..05585b9 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -18,7 +18,7 @@ get_structure, set_atom_map_num, ) -from chebai_graph.preprocessing.utils.properties_constants import * +from chebai_graph.preprocessing.properties.constants import * class GraphPropertyReader(dr.ChemDataReader): @@ -249,6 +249,7 @@ def _augment_graph(self, mol: Mol): fg_edges = {} within_fg_edge_index = [[], []] + # TODO: Can we optimize this ? 
for bond in bonds: start_idx, end_idx = bond[:2] for key, value in new_structure.items(): From 43fe5dc15219d8ca3f9ac5d0cc29cc18d398ad92 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 17:07:57 +0200 Subject: [PATCH 020/224] make data dir a python dir --- chebai_graph/preprocessing/datasets/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 chebai_graph/preprocessing/datasets/__init__.py diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py new file mode 100644 index 0000000..e69de29 From 45e15b72f0881490dbb9097b1e8d5127dadba662 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 17:08:10 +0200 Subject: [PATCH 021/224] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index f9cb175..9676c5b 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,4 @@ cython_debug/ /logs /results_buffer electra_pretrained.ckpt +.isort.cfg From 4e8c5fd1a3c3f2d91772beebac2a3f6dbd3a43d5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 18:20:00 +0200 Subject: [PATCH 022/224] remove properties imports in preprocessing init --- chebai_graph/preprocessing/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai_graph/preprocessing/__init__.py b/chebai_graph/preprocessing/__init__.py index 2b98ba8..e69de29 100644 --- a/chebai_graph/preprocessing/__init__.py +++ b/chebai_graph/preprocessing/__init__.py @@ -1 +0,0 @@ -from chebai_graph.preprocessing.properties import * From c0068870360c69551af3fa183b7397548fb85b3f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 18:33:23 +0200 Subject: [PATCH 023/224] resolve metaclass error --- .../preprocessing/properties/properties.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py index 5ee942b..06693fe 
100644 --- a/chebai_graph/preprocessing/properties/properties.py +++ b/chebai_graph/preprocessing/properties/properties.py @@ -1,4 +1,4 @@ -import abc +from abc import ABC, abstractmethod from typing import Optional import numpy as np @@ -14,7 +14,7 @@ ) -class MolecularProperty(abc.ABC): +class MolecularProperty(ABC): def __init__(self, encoder: Optional[PropertyEncoder] = None): if encoder is None: encoder = IndexEncoder(self) @@ -36,22 +36,24 @@ def get_property_value(self, mol: Chem.rdchem.Mol): raise NotImplementedError -class AtomProperty(MolecularProperty, abc): +class AtomProperty(MolecularProperty, ABC): """Property of an atom.""" def get_property_value(self, mol: Chem.rdchem.Mol): return [self.get_atom_value(atom) for atom in mol.GetAtoms()] + @abstractmethod def get_atom_value(self, atom: Chem.rdchem.Atom): - return NotImplementedError + pass -class BondProperty(MolecularProperty): +class BondProperty(MolecularProperty, ABC): def get_property_value(self, mol: Chem.rdchem.Mol): return [self.get_bond_value(bond) for bond in mol.GetBonds()] + @abstractmethod def get_bond_value(self, bond: Chem.rdchem.Bond): - return NotImplementedError + pass class MoleculeProperty(MolecularProperty): From 501489c67349ef14c12564ce80143398e3165206 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 18:49:30 +0200 Subject: [PATCH 024/224] move rule based reader to class init to avoid circular import --- .../properties/augmented_properties.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 4e57732..a1e870d 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -1,15 +1,14 @@ -import abc +from abc import ABC, abstractmethod from typing import Dict, Optional from rdkit import Chem -from 
chebai_graph.preprocessing import OneHotEncoder, PropertyEncoder from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.properties.constants import * -from chebai_graph.preprocessing.reader import RuleBasedFGReader +from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder -class AugmentedBondProperty(MolecularProperty, abc): +class AugmentedBondProperty(MolecularProperty, ABC): MAIN_KEY = "edges" def get_property_value(self, augmented_mol: Dict): @@ -32,12 +31,12 @@ def get_property_value(self, augmented_mol: Dict): fg_atom_edges = augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE] fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE] - fg_graphNode_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_LEVEL] + fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_LEVEL] if ( not isinstance(fg_atom_edges, dict) or not isinstance(fg_edges, dict) - or not isinstance(fg_graphNode_edges, dict) + or not isinstance(fg_graph_node_edges, dict) ): raise TypeError( f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_LEVEL}"]) ' @@ -49,11 +48,11 @@ def get_property_value(self, augmented_mol: Dict): # https://mail.python.org/pipermail/python-dev/2017-December/151283.html prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges]) prop_list.extend([self.get_bond_value(bond) for bond in fg_edges]) - prop_list.extend([self.get_bond_value(bond) for bond in fg_graphNode_edges]) + prop_list.extend([self.get_bond_value(bond) for bond in fg_graph_node_edges]) return prop_list - @abc.abstractmethod + @abstractmethod def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): pass @@ -74,7 +73,7 @@ def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") -class AugmentedAtomProperty(MolecularProperty, abc): +class AugmentedAtomProperty(MolecularProperty, ABC): MAIN_KEY = "nodes" 
def get_property_value(self, augmented_mol: Dict): @@ -113,7 +112,7 @@ def get_property_value(self, augmented_mol: Dict): return prop_list - @abc.abstractmethod + @abstractmethod def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): pass @@ -146,6 +145,10 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): class AtomFunctionalGroup(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) + + # To avoid circular imports + from chebai_graph.preprocessing.reader import RuleBasedFGReader + self.fg_reader = RuleBasedFGReader() def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): From 45ba894574d313fbdc58897bca91ac24b358ecc9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 19:12:41 +0200 Subject: [PATCH 025/224] right order of import in properties init --- chebai_graph/preprocessing/properties/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 525d248..8bb6953 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -1,2 +1,7 @@ +# Formating is turned off here, because isort sorts the augmented properties imports in first order, +# but it has to be imported after properties module, to avoid circular imports +# fmt: off from .augmented_properties import * from .properties import * + +# fmt: on From 96232ec8c48805fa53db88025c7d556f894a7f0a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 19:13:27 +0200 Subject: [PATCH 026/224] remove graph fg check in augmentor reader --- chebai_graph/preprocessing/reader.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 05585b9..6a048b1 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -152,12 
+152,6 @@ def __init__( self.failed_counter = 0 self.mol_object_buffer = {} - if "graph_fg" not in self.cache: - raise KeyError( - f"Function group `graph_fg` doesn't exits in {self.token_path}. " - f"It should be manually added to token file (preferably at 0th index)" - ) - @classmethod def name(cls): return "graph_fg_augmentor" From 52cc60a8393b264ed7ef6d6a13de898a7399d81b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 19:16:59 +0200 Subject: [PATCH 027/224] right order of imports --- chebai_graph/preprocessing/properties/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 8bb6953..853d02c 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -1,7 +1,8 @@ # Formating is turned off here, because isort sorts the augmented properties imports in first order, # but it has to be imported after properties module, to avoid circular imports -# fmt: off -from .augmented_properties import * +# This is because augmented properties module imports from properties module +# isort: off from .properties import * +from .augmented_properties import * -# fmt: on +# isort: on From 83aa7f998afc925b33ecd4794796820a782eca75 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 21:42:00 +0200 Subject: [PATCH 028/224] make rings as a Functional Group --- chebai_graph/preprocessing/reader.py | 45 ++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 6a048b1..40f9d6d 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -232,13 +232,48 @@ def _augment_graph(self, mol: Mol): fg_atom_edges[f"{num_of_nodes}_{atom}"] = {EDGE_LEVEL: ATOM_FG_EDGE} num_of_edges += 1 - any_atom = 
next(iter(structure[fg]["atom"][0])) # any atom related to fg - fg_nodes[num_of_nodes] = { - NODE_LEVEL: FG_NODE_LEVEL, - "FG": any_atom.GetProp("FG"), - "RING": any_atom.GetProp("RING"), + fg_set = { + mol.GetAtomWithIdx(atom_idx).GetProp("FG") + for atom_idx in structure[fg]["atom"] + if mol.GetAtomWithIdx(atom_idx).GetProp("FG") } + if len(fg_set) > 1: + raise Exception("connected atoms should belong to only one fg") + + elif len(fg_set) == 0: + ring_sizes = set() + for atom_idx in structure[fg]["atom"]: + atom = mol.GetAtomWithIdx(atom_idx) + ring_size_prop = atom.GetProp("RING") + if not ring_size_prop: + raise Exception("All atoms should have ring size") + ring_sizes.add(int(ring_size_prop)) + atom.SetProp("FG", f"RING_{ring_size_prop}") + + # TODO: Incase error is raised check logic for fused rings + assert len(ring_sizes) == 1, "all atoms should have one ring size" + ring_size = list(ring_sizes)[0] + + fg_nodes[num_of_nodes] = { + NODE_LEVEL: FG_NODE_LEVEL, + "FG": f"RING_{ring_size}", + "RING": ring_size, + } + + else: + any_atom = None + for atom_idx in structure[fg]["atom"]: + atom = mol.GetAtomWithIdx(atom_idx) + if atom.GetProp("FG"): + any_atom = atom + + fg_nodes[num_of_nodes] = { + NODE_LEVEL: FG_NODE_LEVEL, + "FG": any_atom.GetProp("FG"), + "RING": any_atom.GetProp("RING"), + } + num_of_nodes += 1 fg_edges = {} From 94d9c088db5e0b3d3c99aae8ed6b850854b39b19 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Apr 2025 21:42:41 +0200 Subject: [PATCH 029/224] utility to visualize augmented molecule --- .../utils/plot_augmented_graph.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 chebai_graph/preprocessing/utils/plot_augmented_graph.py diff --git a/chebai_graph/preprocessing/utils/plot_augmented_graph.py b/chebai_graph/preprocessing/utils/plot_augmented_graph.py new file mode 100644 index 0000000..2920b23 --- /dev/null +++ b/chebai_graph/preprocessing/utils/plot_augmented_graph.py @@ -0,0 +1,71 @@ +import matplotlib 
+import networkx as nx + +from chebai_graph.preprocessing.reader import GraphFGAugmentorReader + +matplotlib.use("TkAgg") # or "QtAgg", if you have PyQt/PySide installed +import matplotlib.pyplot as plt # noqa + + +def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edges): + G = nx.Graph() + + # Node labels and types for visualization + node_labels = {} + node_colors = [] + + # Add atom nodes + atom_nodes = augmented_graph_nodes["atom_nodes"] + for atom in atom_nodes.GetAtoms(): + idx = atom.GetIdx() + label = atom.GetSymbol() + G.add_node(idx) + node_labels[idx] = label + node_colors.append("lightblue") + + # Add functional group nodes + fg_nodes = augmented_graph_nodes["fg_nodes"] + for fg_idx, fg_props in fg_nodes.items(): + label = f"FG:{fg_props['FG']}" + G.add_node(fg_idx) + node_labels[fg_idx] = label + node_colors.append("orange") + + # Add graph-level node + graph_node_idx = augmented_graph_nodes["num_nodes"] + G.add_node(graph_node_idx) + node_labels[graph_node_idx] = "Graph Node" + node_colors.append("red") + + # Add edges + src_nodes, tgt_nodes = edge_index.tolist() + for src, tgt in zip(src_nodes, tgt_nodes): + G.add_edge(src, tgt) + + # Plot the graph + plt.figure(figsize=(10, 8)) + pos = nx.spring_layout(G, seed=42) + nx.draw( + G, + pos, + with_labels=False, + node_color=node_colors, + node_size=600, + edge_color="gray", + ) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) + plt.title("Augmented Molecular Graph") + plt.axis("off") + plt.show() + + +def main(smiles: str): + reader = GraphFGAugmentorReader() + mol = reader._smiles_to_mol(smiles) + edge_index, augmented_nodes, augmented_edges = reader._augment_graph(mol) + plot_augmented_graph(edge_index, augmented_nodes, augmented_edges) + + +if __name__ == "__main__": + smiles = "OC(=O)c1ccccc1O" + main(smiles) From 49d85c5845d47e734923f0e7827d1b6513b5e4c4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 00:30:25 +0200 Subject: [PATCH 
030/224] rename fg_graph edge constant --- .../preprocessing/properties/augmented_properties.py | 4 ++-- chebai_graph/preprocessing/properties/constants.py | 4 ++-- chebai_graph/preprocessing/reader.py | 6 ++---- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index a1e870d..3b7833c 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -31,7 +31,7 @@ def get_property_value(self, augmented_mol: Dict): fg_atom_edges = augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE] fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE] - fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_LEVEL] + fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_EDGE] if ( not isinstance(fg_atom_edges, dict) @@ -39,7 +39,7 @@ def get_property_value(self, augmented_mol: Dict): or not isinstance(fg_graph_node_edges, dict) ): raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_LEVEL}"]) ' + f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_EDGE}"]) ' f"must be an instance of dict containing its properties" ) diff --git a/chebai_graph/preprocessing/properties/constants.py b/chebai_graph/preprocessing/properties/constants.py index fef6cfd..67de13a 100644 --- a/chebai_graph/preprocessing/properties/constants.py +++ b/chebai_graph/preprocessing/properties/constants.py @@ -8,5 +8,5 @@ WITHIN_ATOMS_EDGE = "within_atoms_lvl" WITHIN_FG_EDGE = "within_fg_lvl" ATOM_FG_EDGE = "atom_fg_lvl" -FG_GRAPHNODE_LEVEL = "fg_graphNode_lvl" -EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_LEVEL} +FG_GRAPHNODE_EDGE = "fg_graphNode_lvl" +EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_EDGE} diff --git 
a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 40f9d6d..fbc96ec 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -302,9 +302,7 @@ def _augment_graph(self, mol: Mol): for fg in new_structure.keys(): global_node_edge_index[0].extend([num_of_nodes, fg]) global_node_edge_index[1].extend([fg, num_of_nodes]) - fg_graphNode_edges[f"{num_of_nodes}_{fg}"] = { - NODE_LEVEL: FG_GRAPHNODE_LEVEL - } + fg_graphNode_edges[f"{num_of_nodes}_{fg}"] = {NODE_LEVEL: FG_GRAPHNODE_EDGE} num_of_edges += 1 all_edges = torch.cat( @@ -327,7 +325,7 @@ def _augment_graph(self, mol: Mol): WITHIN_ATOMS_EDGE: mol, ATOM_FG_EDGE: fg_atom_edges, WITHIN_FG_EDGE: fg_edges, - FG_GRAPHNODE_LEVEL: fg_graphNode_edges, + FG_GRAPHNODE_EDGE: fg_graphNode_edges, "num_edges": num_of_edges, } From 8805ecdbf1c2b61a74c71dd8c80ea784b7b235a0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 00:34:46 +0200 Subject: [PATCH 031/224] add diff color to diff edge type --- .../utils/plot_augmented_graph.py | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/chebai_graph/preprocessing/utils/plot_augmented_graph.py b/chebai_graph/preprocessing/utils/plot_augmented_graph.py index 2920b23..b351dbe 100644 --- a/chebai_graph/preprocessing/utils/plot_augmented_graph.py +++ b/chebai_graph/preprocessing/utils/plot_augmented_graph.py @@ -1,6 +1,7 @@ import matplotlib import networkx as nx +from chebai_graph.preprocessing.properties.constants import * from chebai_graph.preprocessing.reader import GraphFGAugmentorReader matplotlib.use("TkAgg") # or "QtAgg", if you have PyQt/PySide installed @@ -18,31 +19,59 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge atom_nodes = augmented_graph_nodes["atom_nodes"] for atom in atom_nodes.GetAtoms(): idx = atom.GetIdx() - label = atom.GetSymbol() G.add_node(idx) - node_labels[idx] = label - node_colors.append("lightblue") + 
node_labels[idx] = atom.GetSymbol() + node_colors.append("#9ecae1") # soft blue # Add functional group nodes fg_nodes = augmented_graph_nodes["fg_nodes"] for fg_idx, fg_props in fg_nodes.items(): - label = f"FG:{fg_props['FG']}" G.add_node(fg_idx) - node_labels[fg_idx] = label - node_colors.append("orange") + node_labels[fg_idx] = f"FG:{fg_props['FG']}" + node_colors.append("#fdae6b") # orange # Add graph-level node graph_node_idx = augmented_graph_nodes["num_nodes"] G.add_node(graph_node_idx) node_labels[graph_node_idx] = "Graph Node" - node_colors.append("red") + node_colors.append("#d62728") # red # Add edges src_nodes, tgt_nodes = edge_index.tolist() + + with_atom_edges = { + f"{bond.GetBeginAtomIdx()}_{bond.GetEndAtomIdx()}" + for bond in augmented_graph_edges[WITHIN_ATOMS_EDGE].GetBonds() + } + atom_fg_edges = set(augmented_graph_edges[ATOM_FG_EDGE]) + within_fg_edges = set(augmented_graph_edges[WITHIN_FG_EDGE]) + fg_graph_edges = set(augmented_graph_edges[FG_GRAPHNODE_EDGE]) + + edge_colors = [] + edge_color_map = { + WITHIN_ATOMS_EDGE: "#1f77b4", # blue + ATOM_FG_EDGE: "#ff7f0e", # orange + WITHIN_FG_EDGE: "#ffbb78", # light orange + FG_GRAPHNODE_EDGE: "#2ca02c", # green + } + for src, tgt in zip(src_nodes, tgt_nodes): + undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} + + if undirected_edge_set & with_atom_edges: + edge_type = WITHIN_ATOMS_EDGE + elif undirected_edge_set & atom_fg_edges: + edge_type = ATOM_FG_EDGE + elif undirected_edge_set & within_fg_edges: + edge_type = WITHIN_FG_EDGE + elif undirected_edge_set & fg_graph_edges: + edge_type = FG_GRAPHNODE_EDGE + else: + raise Exception("Unexpected edge type") + G.add_edge(src, tgt) + edge_colors.append(edge_color_map[edge_type]) - # Plot the graph plt.figure(figsize=(10, 8)) pos = nx.spring_layout(G, seed=42) nx.draw( @@ -51,7 +80,8 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge with_labels=False, node_color=node_colors, node_size=600, - edge_color="gray", + 
edge_color=edge_colors, + width=2, ) nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) plt.title("Augmented Molecular Graph") From 0ff4e86b7c2239da666abbdb32b9615557c8c35e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 00:38:15 +0200 Subject: [PATCH 032/224] Update visualize_augmented_molecule.py --- .../{plot_augmented_graph.py => visualize_augmented_molecule.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename chebai_graph/preprocessing/utils/{plot_augmented_graph.py => visualize_augmented_molecule.py} (100%) diff --git a/chebai_graph/preprocessing/utils/plot_augmented_graph.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py similarity index 100% rename from chebai_graph/preprocessing/utils/plot_augmented_graph.py rename to chebai_graph/preprocessing/utils/visualize_augmented_molecule.py From a78fe249ef36ef848887e40d03bc25a39e4a5cf8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 12:31:32 +0200 Subject: [PATCH 033/224] fix edge color mismatch bug --- .../utils/visualize_augmented_molecule.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index b351dbe..9121e9c 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -36,6 +36,9 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge node_labels[graph_node_idx] = "Graph Node" node_colors.append("#d62728") # red + # Ensure edge_index is undirected by converting it to undirected + edge_index = edge_index + # Add edges src_nodes, tgt_nodes = edge_index.tolist() @@ -47,14 +50,13 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge within_fg_edges = set(augmented_graph_edges[WITHIN_FG_EDGE]) fg_graph_edges = 
set(augmented_graph_edges[FG_GRAPHNODE_EDGE]) - edge_colors = [] edge_color_map = { WITHIN_ATOMS_EDGE: "#1f77b4", # blue - ATOM_FG_EDGE: "#ff7f0e", # orange - WITHIN_FG_EDGE: "#ffbb78", # light orange + ATOM_FG_EDGE: "#9467bd", # purple + WITHIN_FG_EDGE: "#ff7f0e", # orange FG_GRAPHNODE_EDGE: "#2ca02c", # green } - + augmented_edges = [] for src, tgt in zip(src_nodes, tgt_nodes): undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} @@ -69,8 +71,9 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge else: raise Exception("Unexpected edge type") - G.add_edge(src, tgt) - edge_colors.append(edge_color_map[edge_type]) + augmented_edges.append((src, tgt, {"type": edge_type})) + + G.add_edges_from(augmented_edges) plt.figure(figsize=(10, 8)) pos = nx.spring_layout(G, seed=42) @@ -80,7 +83,7 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge with_labels=False, node_color=node_colors, node_size=600, - edge_color=edge_colors, + edge_color=[edge_color_map[data["type"]] for _, _, data in G.edges(data=True)], width=2, ) nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) From 52c106b045e8421dbbab8d0a2c544995281997fd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 14:03:12 +0200 Subject: [PATCH 034/224] better position alignment of the nodes --- .../utils/visualize_augmented_molecule.py | 103 +++++++++++------- 1 file changed, 65 insertions(+), 38 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 9121e9c..18f2757 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -1,47 +1,48 @@ import matplotlib +import matplotlib.pyplot as plt import networkx as nx +from rdkit.Chem import AllChem from chebai_graph.preprocessing.properties.constants import * from 
chebai_graph.preprocessing.reader import GraphFGAugmentorReader -matplotlib.use("TkAgg") # or "QtAgg", if you have PyQt/PySide installed -import matplotlib.pyplot as plt # noqa +matplotlib.use("TkAgg") -def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edges): +def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edges, mol): G = nx.Graph() - # Node labels and types for visualization node_labels = {} node_colors = [] + node_type_map = {} - # Add atom nodes atom_nodes = augmented_graph_nodes["atom_nodes"] + atom_ids = [] for atom in atom_nodes.GetAtoms(): idx = atom.GetIdx() G.add_node(idx) node_labels[idx] = atom.GetSymbol() - node_colors.append("#9ecae1") # soft blue + node_colors.append("#9ecae1") + node_type_map[idx] = "atom" + atom_ids.append(idx) - # Add functional group nodes fg_nodes = augmented_graph_nodes["fg_nodes"] + fg_ids = [] for fg_idx, fg_props in fg_nodes.items(): G.add_node(fg_idx) node_labels[fg_idx] = f"FG:{fg_props['FG']}" - node_colors.append("#fdae6b") # orange + node_colors.append("#fdae6b") + node_type_map[fg_idx] = "fg" + fg_ids.append(fg_idx) - # Add graph-level node graph_node_idx = augmented_graph_nodes["num_nodes"] G.add_node(graph_node_idx) node_labels[graph_node_idx] = "Graph Node" - node_colors.append("#d62728") # red + node_colors.append("#d62728") + node_type_map[graph_node_idx] = "graph" + graph_ids = [graph_node_idx] - # Ensure edge_index is undirected by converting it to undirected - edge_index = edge_index - - # Add edges src_nodes, tgt_nodes = edge_index.tolist() - with_atom_edges = { f"{bond.GetBeginAtomIdx()}_{bond.GetEndAtomIdx()}" for bond in augmented_graph_edges[WITHIN_ATOMS_EDGE].GetBonds() @@ -51,15 +52,16 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge fg_graph_edges = set(augmented_graph_edges[FG_GRAPHNODE_EDGE]) edge_color_map = { - WITHIN_ATOMS_EDGE: "#1f77b4", # blue - ATOM_FG_EDGE: "#9467bd", # purple - WITHIN_FG_EDGE: 
"#ff7f0e", # orange - FG_GRAPHNODE_EDGE: "#2ca02c", # green + WITHIN_ATOMS_EDGE: "#1f77b4", + ATOM_FG_EDGE: "#9467bd", + WITHIN_FG_EDGE: "#ff7f0e", + FG_GRAPHNODE_EDGE: "#2ca02c", } - augmented_edges = [] + + edges = [] + edge_colors = [] for src, tgt in zip(src_nodes, tgt_nodes): undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} - if undirected_edge_set & with_atom_edges: edge_type = WITHIN_ATOMS_EDGE elif undirected_edge_set & atom_fg_edges: @@ -70,24 +72,49 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge edge_type = FG_GRAPHNODE_EDGE else: raise Exception("Unexpected edge type") - - augmented_edges.append((src, tgt, {"type": edge_type})) - - G.add_edges_from(augmented_edges) - + edges.append((src, tgt)) + edge_colors.append(edge_color_map[edge_type]) + G.add_edges_from(edges) + + # 1. Get atom positions from RDKit + AllChem.Compute2DCoords(mol) + atom_pos, max_atom_pos_y = {}, 0 + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + pos = mol.GetConformer().GetAtomPosition(idx) + atom_pos[idx] = (pos.x, pos.y) # Flip y-axis so graph node is on top + if pos.y > max_atom_pos_y: + max_atom_pos_y = pos.y + + # 2. 
Layout for FG and Graph nodes + fg_subgraph = G.subgraph(fg_ids) + fg_pos = nx.spring_layout(fg_subgraph, seed=42) + fg_pos = { + node: (x, y + max_atom_pos_y + 2) for node, (x, y) in fg_pos.items() + } # Below atoms + + graph_node_subgraph = G.subgraph(graph_ids) + graph_pos = nx.spring_layout(graph_node_subgraph, seed=123) + graph_pos = { + node: (x, y + max_atom_pos_y + 3) for node, (x, y) in graph_pos.items() + } # Above atoms + + # Combine all positions + pos = {**atom_pos, **fg_pos, **graph_pos} + + # Final node color mapping + node_colors_final = [ + {"atom": "#9ecae1", "fg": "#fdae6b", "graph": "#d62728"}[node_type_map[n]] + for n in G.nodes + ] + + # Draw plt.figure(figsize=(10, 8)) - pos = nx.spring_layout(G, seed=42) - nx.draw( - G, - pos, - with_labels=False, - node_color=node_colors, - node_size=600, - edge_color=[edge_color_map[data["type"]] for _, _, data in G.edges(data=True)], - width=2, - ) + nx.draw_networkx_nodes(G, pos, node_color=node_colors_final, node_size=600) nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) - plt.title("Augmented Molecular Graph") + nx.draw_networkx_edges(G, pos, edgelist=edges, width=2, edge_color=edge_colors) + + plt.title("Augmented Graph with RDKit Atom Layout + FG/Graph Clusters") plt.axis("off") plt.show() @@ -96,7 +123,7 @@ def main(smiles: str): reader = GraphFGAugmentorReader() mol = reader._smiles_to_mol(smiles) edge_index, augmented_nodes, augmented_edges = reader._augment_graph(mol) - plot_augmented_graph(edge_index, augmented_nodes, augmented_edges) + plot_augmented_graph(edge_index, augmented_nodes, augmented_edges, mol) if __name__ == "__main__": From 5eb462cd94d59600f2f48ceccbcdee6f98dbdb19 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 14:12:15 +0200 Subject: [PATCH 035/224] add CLI --- .../utils/visualize_augmented_molecule.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git 
a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 18f2757..e4860ea 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -1,6 +1,7 @@ import matplotlib import matplotlib.pyplot as plt import networkx as nx +from jsonargparse import CLI from rdkit.Chem import AllChem from chebai_graph.preprocessing.properties.constants import * @@ -119,13 +120,18 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge plt.show() -def main(smiles: str): - reader = GraphFGAugmentorReader() - mol = reader._smiles_to_mol(smiles) - edge_index, augmented_nodes, augmented_edges = reader._augment_graph(mol) - plot_augmented_graph(edge_index, augmented_nodes, augmented_edges, mol) +class Main: + def __init__(self): + self._fg_reader = GraphFGAugmentorReader() + + def plot(self, smiles: str = "OC(=O)c1ccccc1O"): + mol = self._fg_reader._smiles_to_mol(smiles) # noqa + edge_index, augmented_nodes, augmented_edges = self._fg_reader._augment_graph( + mol + ) # noqa + plot_augmented_graph(edge_index, augmented_nodes, augmented_edges, mol) if __name__ == "__main__": - smiles = "OC(=O)c1ccccc1O" - main(smiles) + # use:- visualize_augmented_molecule.py plot --smiles="OC(=O)c1ccccc1O" + CLI(Main) From 56b970fd0be6f7b421393ed8414adddc9a412ae4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 15:54:34 +0200 Subject: [PATCH 036/224] plot based on given plot type --- .../utils/visualize_augmented_molecule.py | 116 +++++++++++------- 1 file changed, 71 insertions(+), 45 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index e4860ea..28cf72e 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ 
b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -10,7 +10,9 @@ matplotlib.use("TkAgg") -def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edges, mol): +def plot_augmented_graph( + edge_index, augmented_graph_nodes, augmented_graph_edges, mol, plot_type +): G = nx.Graph() node_labels = {} @@ -73,63 +75,87 @@ def plot_augmented_graph(edge_index, augmented_graph_nodes, augmented_graph_edge edge_type = FG_GRAPHNODE_EDGE else: raise Exception("Unexpected edge type") - edges.append((src, tgt)) + edges.append((src, tgt, {"type": edge_type})) edge_colors.append(edge_color_map[edge_type]) G.add_edges_from(edges) - # 1. Get atom positions from RDKit - AllChem.Compute2DCoords(mol) - atom_pos, max_atom_pos_y = {}, 0 - for atom in mol.GetAtoms(): - idx = atom.GetIdx() - pos = mol.GetConformer().GetAtomPosition(idx) - atom_pos[idx] = (pos.x, pos.y) # Flip y-axis so graph node is on top - if pos.y > max_atom_pos_y: - max_atom_pos_y = pos.y - - # 2. Layout for FG and Graph nodes - fg_subgraph = G.subgraph(fg_ids) - fg_pos = nx.spring_layout(fg_subgraph, seed=42) - fg_pos = { - node: (x, y + max_atom_pos_y + 2) for node, (x, y) in fg_pos.items() - } # Below atoms - - graph_node_subgraph = G.subgraph(graph_ids) - graph_pos = nx.spring_layout(graph_node_subgraph, seed=123) - graph_pos = { - node: (x, y + max_atom_pos_y + 3) for node, (x, y) in graph_pos.items() - } # Above atoms - - # Combine all positions - pos = {**atom_pos, **fg_pos, **graph_pos} - - # Final node color mapping - node_colors_final = [ - {"atom": "#9ecae1", "fg": "#fdae6b", "graph": "#d62728"}[node_type_map[n]] - for n in G.nodes - ] - - # Draw - plt.figure(figsize=(10, 8)) - nx.draw_networkx_nodes(G, pos, node_color=node_colors_final, node_size=600) - nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) - nx.draw_networkx_edges(G, pos, edgelist=edges, width=2, edge_color=edge_colors) - - plt.title("Augmented Graph with RDKit Atom Layout + FG/Graph 
Clusters") - plt.axis("off") - plt.show() + if plot_type == "h": # hierarchy + # 1. Get atom positions from RDKit + AllChem.Compute2DCoords(mol) + atom_pos, max_atom_pos_y = {}, 0 + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + pos = mol.GetConformer().GetAtomPosition(idx) + atom_pos[idx] = (pos.x, pos.y) # Flip y-axis so graph node is on top + if pos.y > max_atom_pos_y: + max_atom_pos_y = pos.y + + # 2. Layout for FG and Graph nodes + fg_subgraph = G.subgraph(fg_ids) + fg_pos = nx.spring_layout(fg_subgraph, seed=42) + fg_pos = { + node: (x, y + max_atom_pos_y + 2) for node, (x, y) in fg_pos.items() + } # Below atoms + + graph_node_subgraph = G.subgraph(graph_ids) + graph_pos = nx.spring_layout(graph_node_subgraph, seed=123) + graph_pos = { + node: (x, y + max_atom_pos_y + 3) for node, (x, y) in graph_pos.items() + } # Above atoms + + # Combine all positions + pos = {**atom_pos, **fg_pos, **graph_pos} + + # Final node color mapping + node_colors_final = [ + {"atom": "#9ecae1", "fg": "#fdae6b", "graph": "#d62728"}[node_type_map[n]] + for n in G.nodes + ] + + # Draw + plt.figure(figsize=(10, 8)) + nx.draw_networkx_nodes(G, pos, node_color=node_colors_final, node_size=600) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) + nx.draw_networkx_edges(G, pos, edgelist=edges, width=2, edge_color=edge_colors) + + plt.title("Augmented Graph with RDKit Atom Layout + FG/Graph Clusters") + plt.axis("off") + plt.show() + + elif plot_type == "simple": + plt.figure(figsize=(10, 8)) + pos = nx.spring_layout(G, seed=42) + nx.draw( + G, + pos, + with_labels=False, + node_color=node_colors, + node_size=600, + edge_color=[ + edge_color_map[data["type"]] for _, _, data in G.edges(data=True) + ], + width=2, + ) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) + plt.title("Augmented Graph with simple layout") + plt.axis("off") + plt.show() + else: + raise Exception("Unknown plot type") class Main: def __init__(self): self._fg_reader = 
GraphFGAugmentorReader() - def plot(self, smiles: str = "OC(=O)c1ccccc1O"): + def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple"): mol = self._fg_reader._smiles_to_mol(smiles) # noqa edge_index, augmented_nodes, augmented_edges = self._fg_reader._augment_graph( mol ) # noqa - plot_augmented_graph(edge_index, augmented_nodes, augmented_edges, mol) + plot_augmented_graph( + edge_index, augmented_nodes, augmented_edges, mol, plot_type + ) if __name__ == "__main__": From 13f21b20f3c56d9d7de5d8e55dbc83bebf664f0a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 21:07:29 +0200 Subject: [PATCH 037/224] Update .pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ee15ba..108b91d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,4 +22,4 @@ repos: hooks: - id: check-yaml - id: end-of-file-fixer - - id: trailing-whitespace \ No newline at end of file + - id: trailing-whitespace From f2e280318934261ed36c11588e3872f78ff95b13 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 21:07:48 +0200 Subject: [PATCH 038/224] add 3d plot visualization --- .../utils/visualize_augmented_molecule.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 28cf72e..31ceaa0 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -140,6 +140,98 @@ def plot_augmented_graph( plt.title("Augmented Graph with simple layout") plt.axis("off") plt.show() + + elif plot_type == "3d": + from plotly import graph_objects as go + + # Compute 3D coordinates for atoms + AllChem.EmbedMolecule(mol) + conf = mol.GetConformer() + + atom_pos = {} + for atom in mol.GetAtoms(): + 
idx = atom.GetIdx() + pos = conf.GetAtomPosition(idx) + atom_pos[idx] = (pos.x, pos.y, 0) # pos.z + + # Generate 3D layout for FG and Graph nodes using spring layout + fg_pos_3d = nx.spring_layout(G.subgraph(fg_ids), seed=42, dim=3) + graph_pos_3d = nx.spring_layout(G.subgraph(graph_ids), seed=123, dim=3) + + # Offset to avoid overlap with atom layer + max_z = 0 # max(z for _, (_, _, z) in atom_pos.items()) if atom_pos else 0 + fg_pos = {k: (x, y, z + max_z + 2) for k, (x, y, z) in fg_pos_3d.items()} + graph_pos = {k: (x, y, z + max_z + 4) for k, (x, y, z) in graph_pos_3d.items()} + pos = {**atom_pos, **fg_pos, **graph_pos} + + # Group edges by type + edge_type_to_edges = { + WITHIN_ATOMS_EDGE: [], + ATOM_FG_EDGE: [], + WITHIN_FG_EDGE: [], + FG_GRAPHNODE_EDGE: [], + } + + for src, tgt, data in edges: + edge_type_to_edges[data["type"]].append((src, tgt)) + + # Create edge traces + edge_traces = [] + for edge_type, edge_list in edge_type_to_edges.items(): + xs, ys, zs = [], [], [] + for src, tgt in edge_list: + x0, y0, z0 = pos[src] + x1, y1, z1 = pos[tgt] + xs += [x0, x1, None] + ys += [y0, y1, None] + zs += [z0, z1, None] + + trace = go.Scatter3d( + x=xs, + y=ys, + z=zs, + mode="lines", + line=dict(color=edge_color_map[edge_type], width=4), + name=edge_type, + hoverinfo="none", + ) + edge_traces.append(trace) + + # Node trace + + node_trace = go.Scatter3d( + x=[pos[n][0] for n in G.nodes], + y=[pos[n][1] for n in G.nodes], + z=[pos[n][2] for n in G.nodes], + mode="markers+text", + marker=dict( + size=8, + color=[ + {"atom": "#9ecae1", "fg": "#fdae6b", "graph": "#d62728"}[ + node_type_map[n] + ] + for n in G.nodes + ], + opacity=0.9, + ), + text=[node_labels[n] for n in G.nodes], + textposition="top center", + hoverinfo="text", + ) + + # Combine all traces and plot + fig = go.Figure(data=edge_traces + [node_trace]) + fig.update_layout( + title="3D Augmented Molecule Graph", + showlegend=True, + scene=dict( + xaxis=dict(visible=False), + yaxis=dict(visible=False), + 
zaxis=dict(visible=False), + ), + margin=dict(l=0, r=0, b=0, t=40), + ) + fig.show() else: raise Exception("Unknown plot type") From a9a21deb08b504261f2f2fe959a969790fb7ed49 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 21:09:17 +0200 Subject: [PATCH 039/224] pre-commit formatting --- README.md | 4 +-- chebai_graph/models/gin_net.py | 9 +++--- chebai_graph/preprocessing/collate.py | 4 +-- chebai_graph/preprocessing/datasets/chebi.py | 32 ++++++++++--------- .../preprocessing/datasets/pubchem.py | 3 +- .../preprocessing/property_encoder.py | 3 +- .../preprocessing/transform_unlabeled.py | 1 + configs/data/chebi50_graph.yml | 2 +- configs/data/pubchem_graph.yml | 2 +- configs/loss/mask_pretraining.yml | 2 +- configs/model/gnn.yml | 2 +- configs/model/gnn_attention.yml | 2 +- configs/model/gnn_gine.yml | 2 +- configs/model/gnn_res_gated.yml | 2 +- configs/model/gnn_resgated_pretrain.yml | 2 +- pyproject.toml | 2 +- 16 files changed, 40 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index e446b82..8c8f9c2 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,13 @@ ## Installation -Some requirements may not be installed successfully automatically. +Some requirements may not be installed successfully automatically. To install the `torch-` libraries, use `pip install torch-${lib} -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html` where `${lib}` is either `scatter`, `geometric`, `sparse` or `cluster`, and -`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. +`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. 
[torch-geometric docs](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)) diff --git a/chebai_graph/models/gin_net.py b/chebai_graph/models/gin_net.py index 75c2c45..6fed4c6 100644 --- a/chebai_graph/models/gin_net.py +++ b/chebai_graph/models/gin_net.py @@ -1,10 +1,11 @@ +import typing + +import torch +import torch.nn.functional as F +import torch_geometric from torch_scatter import scatter_add from chebai_graph.models.graph import GraphBaseNet -import torch_geometric -import torch.nn.functional as F -import torch -import typing class AggregateMLP(torch.nn.Module): diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py index 2c5f696..4be36cf 100644 --- a/chebai_graph/preprocessing/collate.py +++ b/chebai_graph/preprocessing/collate.py @@ -1,11 +1,11 @@ from typing import Dict import torch +from chebai.preprocessing.collate import RaggedCollator from torch_geometric.data import Data as GeomData from torch_geometric.data.collate import collate as graph_collate -from chebai_graph.preprocessing.structures import XYGraphData -from chebai.preprocessing.collate import RaggedCollator +from chebai_graph.preprocessing.structures import XYGraphData class GraphCollator(RaggedCollator): diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 6ee8bc5..843ba35 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,26 +1,26 @@ -from typing import Optional, List, Callable +import importlib +import os +from typing import Callable, List, Optional +import pandas as pd +import torch +import tqdm +from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, ChEBIOverXPartial, ) -from chebai.preprocessing.datasets.base import XYBaseDataModule from lightning_utilities.core.rank_zero import rank_zero_info +from 
torch_geometric.data.data import Data as GeomData -from chebai_graph.preprocessing.reader import GraphReader, GraphPropertyReader +import chebai_graph.preprocessing.properties as graph_properties from chebai_graph.preprocessing.properties import ( AtomProperty, BondProperty, MolecularProperty, ) -import pandas as pd -from torch_geometric.data.data import Data as GeomData -import torch -import chebai_graph.preprocessing.properties as graph_properties -import importlib -import os -import tqdm +from chebai_graph.preprocessing.reader import GraphPropertyReader, GraphReader class ChEBI50GraphData(ChEBIOver50): @@ -84,9 +84,11 @@ def _setup_properties(self): for file in file_names: # processed_dir_main only exists for ChEBI datasets path = os.path.join( - self.processed_dir_main - if hasattr(self, "processed_dir_main") - else self.raw_dir, + ( + self.processed_dir_main + if hasattr(self, "processed_dir_main") + else self.raw_dir + ), file, ) raw_data += list(self._load_dict(path)) @@ -94,8 +96,8 @@ def _setup_properties(self): features = [row["features"] for row in raw_data] # use vectorized version of encode function, apply only if value is present - enc_if_not_none = ( - lambda encode, value: [encode(atom_v) for atom_v in value] + enc_if_not_none = lambda encode, value: ( + [encode(atom_v) for atom_v in value] if value is not None and len(value) > 0 else None ) diff --git a/chebai_graph/preprocessing/datasets/pubchem.py b/chebai_graph/preprocessing/datasets/pubchem.py index 210b7ab..6f5d118 100644 --- a/chebai_graph/preprocessing/datasets/pubchem.py +++ b/chebai_graph/preprocessing/datasets/pubchem.py @@ -1,6 +1,7 @@ -from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn from chebai.preprocessing.datasets.pubchem import PubchemChem +from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn + class PubChemGraphProperties(GraphPropertiesMixIn, PubchemChem): pass diff --git a/chebai_graph/preprocessing/property_encoder.py 
b/chebai_graph/preprocessing/property_encoder.py index 497025c..ebfbe0c 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -1,8 +1,9 @@ import abc import os -import torch from typing import Optional +import torch + class PropertyEncoder(abc.ABC): def __init__(self, property, **kwargs): diff --git a/chebai_graph/preprocessing/transform_unlabeled.py b/chebai_graph/preprocessing/transform_unlabeled.py index 3920659..0cc4b35 100644 --- a/chebai_graph/preprocessing/transform_unlabeled.py +++ b/chebai_graph/preprocessing/transform_unlabeled.py @@ -1,4 +1,5 @@ import random + import torch diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml index 14cc489..19c8753 100644 --- a/configs/data/chebi50_graph.yml +++ b/configs/data/chebi50_graph.yml @@ -1 +1 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData \ No newline at end of file +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml index af04491..c21f188 100644 --- a/configs/data/pubchem_graph.yml +++ b/configs/data/pubchem_graph.yml @@ -16,4 +16,4 @@ init_args: - chebai_graph.preprocessing.properties.BondInRing - chebai_graph.preprocessing.properties.BondAromaticity #- chebai_graph.preprocessing.properties.MoleculeNumRings - - chebai_graph.preprocessing.properties.RDKit2DNormalized \ No newline at end of file + - chebai_graph.preprocessing.properties.RDKit2DNormalized diff --git a/configs/loss/mask_pretraining.yml b/configs/loss/mask_pretraining.yml index c677559..6d2a560 100644 --- a/configs/loss/mask_pretraining.yml +++ b/configs/loss/mask_pretraining.yml @@ -1 +1 @@ -class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss \ No newline at end of file +class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss diff --git a/configs/model/gnn.yml b/configs/model/gnn.yml index b0b119d..f85fa76 100644 --- 
a/configs/model/gnn.yml +++ b/configs/model/gnn.yml @@ -7,4 +7,4 @@ init_args: hidden_length: 512 dropout_rate: 0.1 n_conv_layers: 3 - n_linear_layers: 3 \ No newline at end of file + n_linear_layers: 3 diff --git a/configs/model/gnn_attention.yml b/configs/model/gnn_attention.yml index b1c553b..0c11ced 100644 --- a/configs/model/gnn_attention.yml +++ b/configs/model/gnn_attention.yml @@ -8,4 +8,4 @@ init_args: dropout_rate: 0.1 n_conv_layers: 5 n_linear_layers: 3 - n_heads: 5 \ No newline at end of file + n_heads: 5 diff --git a/configs/model/gnn_gine.yml b/configs/model/gnn_gine.yml index 0d0ed20..c84ea61 100644 --- a/configs/model/gnn_gine.yml +++ b/configs/model/gnn_gine.yml @@ -8,4 +8,4 @@ init_args: n_conv_layers: 5 n_linear_layers: 3 n_atom_properties: 125 - n_bond_properties: 5 \ No newline at end of file + n_bond_properties: 5 diff --git a/configs/model/gnn_res_gated.yml b/configs/model/gnn_res_gated.yml index d9ddc05..27d1e78 100644 --- a/configs/model/gnn_res_gated.yml +++ b/configs/model/gnn_res_gated.yml @@ -10,4 +10,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 158 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/configs/model/gnn_resgated_pretrain.yml b/configs/model/gnn_resgated_pretrain.yml index c26db76..fad8c27 100644 --- a/configs/model/gnn_resgated_pretrain.yml +++ b/configs/model/gnn_resgated_pretrain.yml @@ -13,4 +13,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 151 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/pyproject.toml b/pyproject.toml index 64c572c..4aea1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,4 +25,4 @@ build-backend = "flit_core.buildapi" requires = ["flit_core >=3.2,<4"] [project.entry-points.'chebai.plugins'] -models = 'chebai_graph.models' \ No newline at end of file +models = 'chebai_graph.models' From 368585bcbf429161141e61e343f9d9ed1b2cb228 Mon Sep 
17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 21:11:16 +0200 Subject: [PATCH 040/224] ploty lib is a soft requirement, use dynamic import error --- .../preprocessing/utils/visualize_augmented_molecule.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 31ceaa0..0db9c30 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -142,7 +142,13 @@ def plot_augmented_graph( plt.show() elif plot_type == "3d": - from plotly import graph_objects as go + try: + from plotly import graph_objects as go + except ImportError: + raise ImportError( + "Plotly is required for 3D plotting. Please install it using:\n\n" + " pip install plotly" + ) # Compute 3D coordinates for atoms AllChem.EmbedMolecule(mol) From 38875c05021209cb25857ffe024a45b5314ed1fb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 21:22:45 +0200 Subject: [PATCH 041/224] modularize plotting code --- .../utils/visualize_augmented_molecule.py | 351 +++++++++--------- 1 file changed, 182 insertions(+), 169 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 0db9c30..174719e 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -9,12 +9,22 @@ matplotlib.use("TkAgg") +EDGE_COLOR_MAP = { + WITHIN_ATOMS_EDGE: "#1f77b4", + ATOM_FG_EDGE: "#9467bd", + WITHIN_FG_EDGE: "#ff7f0e", + FG_GRAPHNODE_EDGE: "#2ca02c", +} -def plot_augmented_graph( - edge_index, augmented_graph_nodes, augmented_graph_edges, mol, plot_type -): - G = nx.Graph() +NODE_COLOR_MAP = { + "atom": "#9ecae1", + "fg": "#fdae6b", + "graph": "#d62728", +} + +def 
_create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges): + G = nx.Graph() node_labels = {} node_colors = [] node_type_map = {} @@ -25,7 +35,7 @@ def plot_augmented_graph( idx = atom.GetIdx() G.add_node(idx) node_labels[idx] = atom.GetSymbol() - node_colors.append("#9ecae1") + node_colors.append(NODE_COLOR_MAP["atom"]) node_type_map[idx] = "atom" atom_ids.append(idx) @@ -34,14 +44,14 @@ def plot_augmented_graph( for fg_idx, fg_props in fg_nodes.items(): G.add_node(fg_idx) node_labels[fg_idx] = f"FG:{fg_props['FG']}" - node_colors.append("#fdae6b") + node_colors.append(NODE_COLOR_MAP["fg"]) node_type_map[fg_idx] = "fg" fg_ids.append(fg_idx) graph_node_idx = augmented_graph_nodes["num_nodes"] G.add_node(graph_node_idx) node_labels[graph_node_idx] = "Graph Node" - node_colors.append("#d62728") + node_colors.append(NODE_COLOR_MAP["graph"]) node_type_map[graph_node_idx] = "graph" graph_ids = [graph_node_idx] @@ -54,13 +64,6 @@ def plot_augmented_graph( within_fg_edges = set(augmented_graph_edges[WITHIN_FG_EDGE]) fg_graph_edges = set(augmented_graph_edges[FG_GRAPHNODE_EDGE]) - edge_color_map = { - WITHIN_ATOMS_EDGE: "#1f77b4", - ATOM_FG_EDGE: "#9467bd", - WITHIN_FG_EDGE: "#ff7f0e", - FG_GRAPHNODE_EDGE: "#2ca02c", - } - edges = [] edge_colors = [] for src, tgt in zip(src_nodes, tgt_nodes): @@ -76,170 +79,180 @@ def plot_augmented_graph( else: raise Exception("Unexpected edge type") edges.append((src, tgt, {"type": edge_type})) - edge_colors.append(edge_color_map[edge_type]) + edge_colors.append(EDGE_COLOR_MAP[edge_type]) + G.add_edges_from(edges) + return ( + G, + node_labels, + node_colors, + node_type_map, + edges, + edge_colors, + atom_ids, + fg_ids, + graph_ids, + ) + + +def _draw_hierarchy( + G, mol, node_labels, node_type_map, edges, edge_colors, fg_ids, graph_ids +): + AllChem.Compute2DCoords(mol) + atom_pos, max_atom_pos_y = {}, 0 + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + pos = mol.GetConformer().GetAtomPosition(idx) + atom_pos[idx] = 
(pos.x, pos.y) + max_atom_pos_y = max(max_atom_pos_y, pos.y) - if plot_type == "h": # hierarchy - # 1. Get atom positions from RDKit - AllChem.Compute2DCoords(mol) - atom_pos, max_atom_pos_y = {}, 0 - for atom in mol.GetAtoms(): - idx = atom.GetIdx() - pos = mol.GetConformer().GetAtomPosition(idx) - atom_pos[idx] = (pos.x, pos.y) # Flip y-axis so graph node is on top - if pos.y > max_atom_pos_y: - max_atom_pos_y = pos.y - - # 2. Layout for FG and Graph nodes - fg_subgraph = G.subgraph(fg_ids) - fg_pos = nx.spring_layout(fg_subgraph, seed=42) - fg_pos = { - node: (x, y + max_atom_pos_y + 2) for node, (x, y) in fg_pos.items() - } # Below atoms - - graph_node_subgraph = G.subgraph(graph_ids) - graph_pos = nx.spring_layout(graph_node_subgraph, seed=123) - graph_pos = { - node: (x, y + max_atom_pos_y + 3) for node, (x, y) in graph_pos.items() - } # Above atoms - - # Combine all positions - pos = {**atom_pos, **fg_pos, **graph_pos} - - # Final node color mapping - node_colors_final = [ - {"atom": "#9ecae1", "fg": "#fdae6b", "graph": "#d62728"}[node_type_map[n]] - for n in G.nodes - ] - - # Draw - plt.figure(figsize=(10, 8)) - nx.draw_networkx_nodes(G, pos, node_color=node_colors_final, node_size=600) - nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) - nx.draw_networkx_edges(G, pos, edgelist=edges, width=2, edge_color=edge_colors) - - plt.title("Augmented Graph with RDKit Atom Layout + FG/Graph Clusters") - plt.axis("off") - plt.show() + fg_pos = { + node: (x, y + max_atom_pos_y + 2) + for node, (x, y) in nx.spring_layout(G.subgraph(fg_ids), seed=42).items() + } + graph_pos = { + node: (x, y + max_atom_pos_y + 3) + for node, (x, y) in nx.spring_layout(G.subgraph(graph_ids), seed=123).items() + } - elif plot_type == "simple": - plt.figure(figsize=(10, 8)) - pos = nx.spring_layout(G, seed=42) - nx.draw( - G, - pos, - with_labels=False, - node_color=node_colors, - node_size=600, - edge_color=[ - edge_color_map[data["type"]] for _, _, data in 
G.edges(data=True) - ], - width=2, + pos = {**atom_pos, **fg_pos, **graph_pos} + node_colors = [NODE_COLOR_MAP[node_type_map[n]] for n in G.nodes] + + plt.figure(figsize=(10, 8)) + nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=600) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) + nx.draw_networkx_edges(G, pos, edgelist=edges, width=2, edge_color=edge_colors) + plt.title("Augmented Graph with RDKit Atom Layout + FG/Graph Clusters") + plt.axis("off") + plt.show() + + +def _draw_simple(G, node_labels, node_colors): + plt.figure(figsize=(10, 8)) + pos = nx.spring_layout(G, seed=42) + nx.draw( + G, + pos, + with_labels=False, + node_color=node_colors, + node_size=600, + edge_color=[EDGE_COLOR_MAP[data["type"]] for _, _, data in G.edges(data=True)], + width=2, + ) + nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) + plt.title("Augmented Graph with simple layout") + plt.axis("off") + plt.show() + + +def _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids): + try: + from plotly import graph_objects as go + except ImportError: + raise ImportError( + "Plotly is required for 3D plotting. Install it with `pip install plotly`." ) - nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) - plt.title("Augmented Graph with simple layout") - plt.axis("off") - plt.show() - elif plot_type == "3d": - try: - from plotly import graph_objects as go - except ImportError: - raise ImportError( - "Plotly is required for 3D plotting. 
Please install it using:\n\n" - " pip install plotly" - ) - - # Compute 3D coordinates for atoms - AllChem.EmbedMolecule(mol) - conf = mol.GetConformer() - - atom_pos = {} - for atom in mol.GetAtoms(): - idx = atom.GetIdx() - pos = conf.GetAtomPosition(idx) - atom_pos[idx] = (pos.x, pos.y, 0) # pos.z - - # Generate 3D layout for FG and Graph nodes using spring layout - fg_pos_3d = nx.spring_layout(G.subgraph(fg_ids), seed=42, dim=3) - graph_pos_3d = nx.spring_layout(G.subgraph(graph_ids), seed=123, dim=3) - - # Offset to avoid overlap with atom layer - max_z = 0 # max(z for _, (_, _, z) in atom_pos.items()) if atom_pos else 0 - fg_pos = {k: (x, y, z + max_z + 2) for k, (x, y, z) in fg_pos_3d.items()} - graph_pos = {k: (x, y, z + max_z + 4) for k, (x, y, z) in graph_pos_3d.items()} - pos = {**atom_pos, **fg_pos, **graph_pos} - - # Group edges by type - edge_type_to_edges = { - WITHIN_ATOMS_EDGE: [], - ATOM_FG_EDGE: [], - WITHIN_FG_EDGE: [], - FG_GRAPHNODE_EDGE: [], - } - - for src, tgt, data in edges: - edge_type_to_edges[data["type"]].append((src, tgt)) - - # Create edge traces - edge_traces = [] - for edge_type, edge_list in edge_type_to_edges.items(): - xs, ys, zs = [], [], [] - for src, tgt in edge_list: - x0, y0, z0 = pos[src] - x1, y1, z1 = pos[tgt] - xs += [x0, x1, None] - ys += [y0, y1, None] - zs += [z0, z1, None] - - trace = go.Scatter3d( - x=xs, - y=ys, - z=zs, - mode="lines", - line=dict(color=edge_color_map[edge_type], width=4), - name=edge_type, - hoverinfo="none", - ) - edge_traces.append(trace) - - # Node trace - - node_trace = go.Scatter3d( - x=[pos[n][0] for n in G.nodes], - y=[pos[n][1] for n in G.nodes], - z=[pos[n][2] for n in G.nodes], - mode="markers+text", - marker=dict( - size=8, - color=[ - {"atom": "#9ecae1", "fg": "#fdae6b", "graph": "#d62728"}[ - node_type_map[n] - ] - for n in G.nodes - ], - opacity=0.9, - ), - text=[node_labels[n] for n in G.nodes], - textposition="top center", - hoverinfo="text", + AllChem.EmbedMolecule(mol) + conf = 
mol.GetConformer() + + atom_pos = { + atom.GetIdx(): (pos.x, pos.y, 0) + for atom in mol.GetAtoms() + for pos in [conf.GetAtomPosition(atom.GetIdx())] + } + + fg_pos_3d = nx.spring_layout(G.subgraph(fg_ids), seed=42, dim=3) + graph_pos_3d = nx.spring_layout(G.subgraph(graph_ids), seed=123, dim=3) + fg_pos = {k: (x, y, z + 2) for k, (x, y, z) in fg_pos_3d.items()} + graph_pos = {k: (x, y, z + 4) for k, (x, y, z) in graph_pos_3d.items()} + + pos = {**atom_pos, **fg_pos, **graph_pos} + + edge_type_to_edges = { + WITHIN_ATOMS_EDGE: [], + ATOM_FG_EDGE: [], + WITHIN_FG_EDGE: [], + FG_GRAPHNODE_EDGE: [], + } + for src, tgt, data in edges: + edge_type_to_edges[data["type"]].append((src, tgt)) + + edge_traces = [] + for edge_type, edge_list in edge_type_to_edges.items(): + xs, ys, zs = [], [], [] + for src, tgt in edge_list: + x0, y0, z0 = pos[src] + x1, y1, z1 = pos[tgt] + xs += [x0, x1, None] + ys += [y0, y1, None] + zs += [z0, z1, None] + + trace = go.Scatter3d( + x=xs, + y=ys, + z=zs, + mode="lines", + line=dict(color=EDGE_COLOR_MAP[edge_type], width=4), + name=edge_type, + hoverinfo="none", ) + edge_traces.append(trace) + + node_trace = go.Scatter3d( + x=[pos[n][0] for n in G.nodes], + y=[pos[n][1] for n in G.nodes], + z=[pos[n][2] for n in G.nodes], + mode="markers+text", + marker=dict( + size=8, + color=[NODE_COLOR_MAP[node_type_map[n]] for n in G.nodes], + opacity=0.9, + ), + text=[node_labels[n] for n in G.nodes], + textposition="top center", + hoverinfo="text", + ) + + fig = go.Figure(data=edge_traces + [node_trace]) + fig.update_layout( + title="3D Augmented Molecule Graph", + showlegend=True, + scene=dict( + xaxis=dict(visible=False), + yaxis=dict(visible=False), + zaxis=dict(visible=False), + ), + margin=dict(l=0, r=0, b=0, t=40), + ) + fig.show() - # Combine all traces and plot - fig = go.Figure(data=edge_traces + [node_trace]) - fig.update_layout( - title="3D Augmented Molecule Graph", - showlegend=True, - scene=dict( - xaxis=dict(visible=False), - 
yaxis=dict(visible=False), - zaxis=dict(visible=False), - ), - margin=dict(l=0, r=0, b=0, t=40), + +def plot_augmented_graph( + edge_index, augmented_graph_nodes, augmented_graph_edges, mol, plot_type +): + ( + G, + node_labels, + node_colors, + node_type_map, + edges, + edge_colors, + atom_ids, + fg_ids, + graph_ids, + ) = _create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges) + + if plot_type == "h": + _draw_hierarchy( + G, mol, node_labels, node_type_map, edges, edge_colors, fg_ids, graph_ids ) - fig.show() + elif plot_type == "simple": + _draw_simple(G, node_labels, node_colors) + elif plot_type == "3d": + _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids) else: - raise Exception("Unknown plot type") + raise ValueError(f"Unknown plot type: {plot_type}") class Main: From c4e1cbfd962d8f6b9ffa8848006196b72d1d313c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 30 Apr 2025 23:07:00 +0200 Subject: [PATCH 042/224] restructure and add docstring, typehints --- .../utils/visualize_augmented_molecule.py | 243 ++++++++++++------ 1 file changed, 162 insertions(+), 81 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 174719e..ad31308 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -1,7 +1,9 @@ import matplotlib import matplotlib.pyplot as plt import networkx as nx +import torch from jsonargparse import CLI +from rdkit import Chem from rdkit.Chem import AllChem from chebai_graph.preprocessing.properties.constants import * @@ -23,38 +25,53 @@ } -def _create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges): +def _create_graph( + edge_index: torch.Tensor, augmented_graph_nodes: dict, augmented_graph_edges: dict +) -> nx.Graph: + """ + Create a NetworkX graph from augmented molecular information. 
+ + Args: + edge_index (torch.Tensor): Tensor of shape (2, num_edges) with source and target indices. + augmented_graph_nodes (dict): Dictionary of node attributes grouped by type ('atom_nodes', 'fg_nodes', etc.). + augmented_graph_edges (dict): Dictionary of edges grouped by predefined edge type. + + Returns: + nx.Graph: Constructed NetworkX graph with annotated nodes and edges. + """ G = nx.Graph() - node_labels = {} - node_colors = [] - node_type_map = {} + # Add atom nodes atom_nodes = augmented_graph_nodes["atom_nodes"] - atom_ids = [] for atom in atom_nodes.GetAtoms(): idx = atom.GetIdx() - G.add_node(idx) - node_labels[idx] = atom.GetSymbol() - node_colors.append(NODE_COLOR_MAP["atom"]) - node_type_map[idx] = "atom" - atom_ids.append(idx) + G.add_node( + idx, + node_name=atom.GetSymbol(), + node_type="atom", + node_color=NODE_COLOR_MAP["atom"], + ) + # Add functional group (FG) nodes fg_nodes = augmented_graph_nodes["fg_nodes"] - fg_ids = [] for fg_idx, fg_props in fg_nodes.items(): - G.add_node(fg_idx) - node_labels[fg_idx] = f"FG:{fg_props['FG']}" - node_colors.append(NODE_COLOR_MAP["fg"]) - node_type_map[fg_idx] = "fg" - fg_ids.append(fg_idx) + G.add_node( + fg_idx, + node_name=f"FG:{fg_props['FG']}", + node_type="fg", + node_color=NODE_COLOR_MAP["fg"], + ) + # Add special graph node graph_node_idx = augmented_graph_nodes["num_nodes"] - G.add_node(graph_node_idx) - node_labels[graph_node_idx] = "Graph Node" - node_colors.append(NODE_COLOR_MAP["graph"]) - node_type_map[graph_node_idx] = "graph" - graph_ids = [graph_node_idx] + G.add_node( + graph_node_idx, + node_name="Graph Node", + node_type="graph", + node_color=NODE_COLOR_MAP["graph"], + ) + # Decode edge types and add edges with proper color and type src_nodes, tgt_nodes = edge_index.tolist() with_atom_edges = { f"{bond.GetBeginAtomIdx()}_{bond.GetEndAtomIdx()}" @@ -64,8 +81,6 @@ def _create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges): within_fg_edges = 
set(augmented_graph_edges[WITHIN_FG_EDGE]) fg_graph_edges = set(augmented_graph_edges[FG_GRAPHNODE_EDGE]) - edges = [] - edge_colors = [] for src, tgt in zip(src_nodes, tgt_nodes): undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} if undirected_edge_set & with_atom_edges: @@ -78,65 +93,95 @@ def _create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges): edge_type = FG_GRAPHNODE_EDGE else: raise Exception("Unexpected edge type") - edges.append((src, tgt, {"type": edge_type})) - edge_colors.append(EDGE_COLOR_MAP[edge_type]) + G.add_edge(src, tgt, edge_type=edge_type, edge_color=EDGE_COLOR_MAP[edge_type]) + + return G + + +def _get_subgraph_by_node_type(G: nx.Graph, node_type: str) -> nx.Graph: + """ + Extract a subgraph containing only nodes of the given type. + + Args: + G (nx.Graph): Full graph with node_type attributes. + node_type (str): Type of node to extract ('atom', 'fg', or 'graph'). + + Returns: + nx.Graph: Subgraph with selected node type. + """ + selected_nodes = [ + n for n, attr in G.nodes(data=True) if attr.get("node_type") == node_type + ] + return G.subgraph(selected_nodes).copy() - G.add_edges_from(edges) - return ( - G, - node_labels, - node_colors, - node_type_map, - edges, - edge_colors, - atom_ids, - fg_ids, - graph_ids, - ) +def _draw_hierarchy(G: nx.Graph, mol: Chem.Mol) -> None: + """ + Draw a hierarchical layout combining RDKit 2D coordinates for atoms and spring layout for FG/graph nodes. -def _draw_hierarchy( - G, mol, node_labels, node_type_map, edges, edge_colors, fg_ids, graph_ids -): + Args: + G (nx.Graph): Augmented molecular graph. + mol (Chem.Mol): RDKit molecule object with atom layout. 
+ """ AllChem.Compute2DCoords(mol) - atom_pos, max_atom_pos_y = {}, 0 + + # Get 2D positions from RDKit + atom_pos = {} + max_atom_pos_y = 0 for atom in mol.GetAtoms(): idx = atom.GetIdx() pos = mol.GetConformer().GetAtomPosition(idx) atom_pos[idx] = (pos.x, pos.y) max_atom_pos_y = max(max_atom_pos_y, pos.y) + # Position FG nodes above atoms + fg_graph = _get_subgraph_by_node_type(G, "fg") fg_pos = { node: (x, y + max_atom_pos_y + 2) - for node, (x, y) in nx.spring_layout(G.subgraph(fg_ids), seed=42).items() + for node, (x, y) in nx.spring_layout(fg_graph, seed=42).items() } + + # Position the graph node further above + graph_node_graph = _get_subgraph_by_node_type(G, "graph") graph_pos = { node: (x, y + max_atom_pos_y + 3) - for node, (x, y) in nx.spring_layout(G.subgraph(graph_ids), seed=123).items() + for node, (x, y) in nx.spring_layout(graph_node_graph, seed=123).items() } + # Merge all positions pos = {**atom_pos, **fg_pos, **graph_pos} - node_colors = [NODE_COLOR_MAP[node_type_map[n]] for n in G.nodes] + node_colors = [G.nodes[n]["node_color"] for n in G.nodes] + node_labels = {n: G.nodes[n]["node_name"] for n in G.nodes} + edge_colors = [G.edges[e]["edge_color"] for e in G.edges] plt.figure(figsize=(10, 8)) nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=600) nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) - nx.draw_networkx_edges(G, pos, edgelist=edges, width=2, edge_color=edge_colors) + nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=2) plt.title("Augmented Graph with RDKit Atom Layout + FG/Graph Clusters") plt.axis("off") plt.show() -def _draw_simple(G, node_labels, node_colors): +def _draw_simple(G: nx.Graph) -> None: + """ + Draw the graph using a simple spring layout. + + Args: + G (nx.Graph): Augmented molecular graph. 
+ """ plt.figure(figsize=(10, 8)) pos = nx.spring_layout(G, seed=42) + node_colors = [G.nodes[n]["node_color"] for n in G.nodes] + node_labels = {n: G.nodes[n]["node_name"] for n in G.nodes} + edge_colors = [G.edges[e]["edge_color"] for e in G.edges] nx.draw( G, pos, with_labels=False, node_color=node_colors, node_size=600, - edge_color=[EDGE_COLOR_MAP[data["type"]] for _, _, data in G.edges(data=True)], + edge_color=edge_colors, width=2, ) nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10) @@ -145,7 +190,17 @@ def _draw_simple(G, node_labels, node_colors): plt.show() -def _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids): +def _draw_3d(G: nx.Graph, mol: Chem.Mol) -> None: + """ + Visualize the graph in 3D using Plotly. + + Args: + G (nx.Graph): Augmented molecular graph. + mol (Chem.Mol): RDKit molecule object for 3D coordinates. + + Raises: + ImportError: If Plotly is not installed. + """ try: from plotly import graph_objects as go except ImportError: @@ -153,6 +208,7 @@ def _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids): "Plotly is required for 3D plotting. Install it with `pip install plotly`." 
) + # Generate 3D coordinates for atoms AllChem.EmbedMolecule(mol) conf = mol.GetConformer() @@ -162,21 +218,26 @@ def _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids): for pos in [conf.GetAtomPosition(atom.GetIdx())] } - fg_pos_3d = nx.spring_layout(G.subgraph(fg_ids), seed=42, dim=3) - graph_pos_3d = nx.spring_layout(G.subgraph(graph_ids), seed=123, dim=3) + # Generate 3D layout for FG and graph nodes + fg_graph = _get_subgraph_by_node_type(G, "fg") + fg_pos_3d = nx.spring_layout(fg_graph, seed=42, dim=3) fg_pos = {k: (x, y, z + 2) for k, (x, y, z) in fg_pos_3d.items()} + + graph_node_graph = _get_subgraph_by_node_type(G, "graph") + graph_pos_3d = nx.spring_layout(graph_node_graph, seed=123, dim=3) graph_pos = {k: (x, y, z + 4) for k, (x, y, z) in graph_pos_3d.items()} pos = {**atom_pos, **fg_pos, **graph_pos} + # Collect edges by type edge_type_to_edges = { WITHIN_ATOMS_EDGE: [], ATOM_FG_EDGE: [], WITHIN_FG_EDGE: [], FG_GRAPHNODE_EDGE: [], } - for src, tgt, data in edges: - edge_type_to_edges[data["type"]].append((src, tgt)) + for src, tgt, data in G.edges(data=True): + edge_type_to_edges[data["edge_type"]].append((src, tgt)) edge_traces = [] for edge_type, edge_list in edge_type_to_edges.items(): @@ -199,17 +260,21 @@ def _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids): ) edge_traces.append(trace) + # Collect node attributes for visualization + pos_x, pos_y, pos_z, node_colors, node_names = zip( + *[ + (pos[n][0], pos[n][1], pos[n][2], attr["node_color"], attr["node_name"]) + for n, attr in G.nodes(data=True) + ] + ) + node_trace = go.Scatter3d( - x=[pos[n][0] for n in G.nodes], - y=[pos[n][1] for n in G.nodes], - z=[pos[n][2] for n in G.nodes], + x=pos_x, + y=pos_y, + z=pos_z, mode="markers+text", - marker=dict( - size=8, - color=[NODE_COLOR_MAP[node_type_map[n]] for n in G.nodes], - opacity=0.9, - ), - text=[node_labels[n] for n in G.nodes], + marker=dict(size=8, color=node_colors, opacity=0.9), + text=node_names, 
textposition="top center", hoverinfo="text", ) @@ -229,37 +294,53 @@ def _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids): def plot_augmented_graph( - edge_index, augmented_graph_nodes, augmented_graph_edges, mol, plot_type -): - ( - G, - node_labels, - node_colors, - node_type_map, - edges, - edge_colors, - atom_ids, - fg_ids, - graph_ids, - ) = _create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges) + edge_index: torch.Tensor, + augmented_graph_nodes: dict, + augmented_graph_edges: dict, + mol: Chem.Mol, + plot_type: str, +) -> None: + """ + Main plotting function to visualize the augmented graph. + + Args: + edge_index (torch.Tensor): Edge indices tensor (2, num_edges). + augmented_graph_nodes (dict): Node metadata. + augmented_graph_edges (dict): Edge metadata. + mol (Chem.Mol): RDKit molecule object. + plot_type (str): One of ["simple", "h", "3d"]. + """ + G = _create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges) if plot_type == "h": - _draw_hierarchy( - G, mol, node_labels, node_type_map, edges, edge_colors, fg_ids, graph_ids - ) + _draw_hierarchy(G, mol) elif plot_type == "simple": - _draw_simple(G, node_labels, node_colors) + _draw_simple(G) elif plot_type == "3d": - _draw_3d(G, mol, node_labels, node_type_map, edges, fg_ids, graph_ids) + _draw_3d(G, mol) else: raise ValueError(f"Unknown plot type: {plot_type}") class Main: + """ + Command-line wrapper class for plotting augmented molecular graphs. + """ + def __init__(self): self._fg_reader = GraphFGAugmentorReader() - def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple"): + def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> None: + """ + Plot an augmented molecular graph from SMILES. + + Args: + smiles (str): SMILES string to parse. + plot_type (str): Type of plot. One of ['simple', 'h', '3d']. 
+ - simple : 2D graph with all nodes on same plane + - h: Hierarchical 2D-graph with separate plane for each node type + - 3d: Hierarchical 3D-graph + """ mol = self._fg_reader._smiles_to_mol(smiles) # noqa edge_index, augmented_nodes, augmented_edges = self._fg_reader._augment_graph( mol @@ -270,5 +351,5 @@ def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple"): if __name__ == "__main__": - # use:- visualize_augmented_molecule.py plot --smiles="OC(=O)c1ccccc1O" + # Example: python visualize_augmented_molecule.py plot --smiles="OC(=O)c1ccccc1O" --plot_type="h" CLI(Main) From 468bfc8f6b4b0a42eeeac26634c1bb96a092d8f0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 1 May 2025 11:57:17 +0200 Subject: [PATCH 043/224] proper namespace to avoid future conflicts --- .../preprocessing/properties/__init__.py | 45 ++++++++++++++++++- chebai_graph/preprocessing/reader.py | 10 ++--- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 853d02c..45a5b65 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -2,7 +2,48 @@ # but it has to be imported after properties module, to avoid circular imports # This is because augmented properties module imports from properties module # isort: off -from .properties import * -from .augmented_properties import * +from .properties import ( + MolecularProperty, + AtomType, + NumAtomBonds, + AtomCharge, + AtomChirality, + AtomHybridization, + AtomNumHs, + AtomAromaticity, + BondAromaticity, + BondType, + BondInRing, + MoleculeNumRings, + RDKit2DNormalized, +) + +from .augmented_properties import ( + AtomNodeLevel, + AtomFunctionalGroup, + AtomRingSize, + BondLevel, +) # isort: on + +__all__ = [ + "MolecularProperty", + "AtomType", + "NumAtomBonds", + "AtomCharge", + "AtomChirality", + "AtomHybridization", + "AtomNumHs", + 
"AtomAromaticity", + "BondAromaticity", + "BondType", + "BondInRing", + "MoleculeNumRings", + "RDKit2DNormalized", + # -------- Augmented Molecular Properties -------- + "AtomNodeLevel", + "AtomFunctionalGroup", + "AtomRingSize", + "BondLevel", +] diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index fbc96ec..cb0fb01 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -11,13 +11,13 @@ from torch_geometric.data import Data as GeomData from torch_geometric.utils import from_networkx -import chebai_graph.preprocessing.properties as properties from chebai_graph.preprocessing.collate import GraphCollator from chebai_graph.preprocessing.fg_detection.rule_based import ( detect_functional_group, get_structure, set_atom_map_num, ) +from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.properties.constants import * @@ -76,9 +76,7 @@ def on_finish(self): rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} - def read_property( - self, smiles: str, property: properties.MolecularProperty - ) -> Optional[List]: + def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: mol = self._smiles_to_mol(smiles) if mol is None: return None @@ -352,9 +350,7 @@ def on_finish(self): rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} - def read_property( - self, smiles: str, property: properties.MolecularProperty - ) -> Optional[List]: + def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: mol = self._smiles_to_mol(smiles) if mol is None: return None From f9a0da272ce1f6aeb5a6ff4b85c292b51f15a637 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 1 May 2025 12:06:17 +0200 Subject: [PATCH 044/224] seperate file for augmented reader --- chebai_graph/preprocessing/reader/__init__.py | 9 ++ .../{reader.py => 
reader/augmented_reader.py} | 134 +----------------- chebai_graph/preprocessing/reader/reader.py | 131 +++++++++++++++++ 3 files changed, 146 insertions(+), 128 deletions(-) create mode 100644 chebai_graph/preprocessing/reader/__init__.py rename chebai_graph/preprocessing/{reader.py => reader/augmented_reader.py} (68%) create mode 100644 chebai_graph/preprocessing/reader/reader.py diff --git a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py new file mode 100644 index 0000000..737ee92 --- /dev/null +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -0,0 +1,9 @@ +from .augmented_reader import GraphFGAugmentorReader, RuleBasedFGReader +from .reader import GraphPropertyReader, GraphReader + +__all__ = [ + "GraphReader", + "GraphPropertyReader", + "GraphFGAugmentorReader", + "RuleBasedFGReader", +] diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py similarity index 68% rename from chebai_graph/preprocessing/reader.py rename to chebai_graph/preprocessing/reader/augmented_reader.py index cb0fb01..ca5c7c8 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,15 +1,10 @@ -import os from typing import List, Optional -import chebai.preprocessing.reader as dr -import networkx as nx -import pysmiles as ps -import rdkit.Chem as Chem -import torch +from chebai.preprocessing.reader import ChemDataReader from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn -from rdkit.Chem import Mol +from rdkit import Chem from torch_geometric.data import Data as GeomData -from torch_geometric.utils import from_networkx +from wandb.integration.torch.wandb_torch import torch from chebai_graph.preprocessing.collate import GraphCollator from chebai_graph.preprocessing.fg_detection.rule_based import ( @@ -21,124 +16,7 @@ from chebai_graph.preprocessing.properties.constants import * -class 
GraphPropertyReader(dr.ChemDataReader): - COLLATOR = GraphCollator - - def __init__( - self, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.failed_counter = 0 - self.mol_object_buffer = {} - - @classmethod - def name(cls): - return "graph_properties" - - def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: - """Load smiles into rdkit, store object in buffer""" - if smiles in self.mol_object_buffer: - return self.mol_object_buffer[smiles] - - mol = Chem.MolFromSmiles(smiles) - if mol is None: - rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") - self.failed_counter += 1 - else: - try: - Chem.SanitizeMol(mol) - except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") - self.failed_counter += 1 - self.mol_object_buffer[smiles] = mol - return mol - - def _read_data(self, raw_data): - mol = self._smiles_to_mol(raw_data) - if mol is None: - return None - - x = torch.zeros((mol.GetNumAtoms(), 0)) - - edge_attr = torch.zeros((mol.GetNumBonds(), 0)) - - edge_index = torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] - ) - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) - - def on_finish(self): - rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") - self.mol_object_buffer = {} - - def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: - mol = self._smiles_to_mol(smiles) - if mol is None: - return None - return property.get_property_value(mol) - - -class GraphReader(dr.ChemDataReader): - """Reads each atom as one token (atom symbol + charge), reads bond order as edge attribute. 
- Creates nx Graph from SMILES.""" - - COLLATOR = GraphCollator - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dirname = os.path.dirname(__file__) - - @classmethod - def name(cls): - return "graph" - - def _read_data(self, raw_data) -> Optional[GeomData]: - # raw_data is a SMILES string - try: - mol = ps.read_smiles(raw_data) - except ValueError: - return None - assert isinstance(mol, nx.Graph) - d = {} - de = {} - for node in mol.nodes: - n = mol.nodes[node] - try: - m = n["element"] - charge = n["charge"] - if charge: - if charge > 0: - m += "+" - else: - m += "-" - charge *= -1 - if charge > 1: - m += str(charge) - m = f"[{m}]" - except KeyError: - m = "*" - d[node] = self._get_token_index(m) - for attr in list(mol.nodes[node].keys()): - del mol.nodes[node][attr] - for edge in mol.edges: - de[edge] = mol.edges[edge]["order"] - for attr in list(mol.edges[edge].keys()): - del mol.edges[edge][attr] - nx.set_node_attributes(mol, d, "x") - nx.set_edge_attributes(mol, de, "edge_attr") - data = from_networkx(mol) - return data - - def collate(self, list_of_tuples): - return self.collator(list_of_tuples) - - -class GraphFGAugmentorReader(dr.ChemDataReader): +class GraphFGAugmentorReader(ChemDataReader): COLLATOR = GraphCollator def __init__( @@ -189,7 +67,7 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: self.failed_counter += 1 return mol - def _augment_graph(self, mol: Mol): + def _augment_graph(self, mol: Chem.Mol): edge_index = torch.tensor( [ [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], @@ -362,7 +240,7 @@ def read_property(self, smiles: str, property: MolecularProperty) -> Optional[Li return property.get_property_value(mol) -class RuleBasedFGReader(dr.ChemDataReader): +class RuleBasedFGReader(ChemDataReader): @classmethod def name(cls) -> str: diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py new file mode 100644 index 0000000..bdfa172 --- 
/dev/null +++ b/chebai_graph/preprocessing/reader/reader.py @@ -0,0 +1,131 @@ +import os +from typing import List, Optional + +import chebai.preprocessing.reader as dr +import networkx as nx +import pysmiles as ps +import rdkit.Chem as Chem +import torch +from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn +from torch_geometric.data import Data as GeomData +from torch_geometric.utils import from_networkx + +from chebai_graph.preprocessing.collate import GraphCollator +from chebai_graph.preprocessing.properties import MolecularProperty + + +class GraphPropertyReader(dr.ChemDataReader): + COLLATOR = GraphCollator + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.failed_counter = 0 + self.mol_object_buffer = {} + + @classmethod + def name(cls): + return "graph_properties" + + def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: + """Load smiles into rdkit, store object in buffer""" + if smiles in self.mol_object_buffer: + return self.mol_object_buffer[smiles] + + mol = Chem.MolFromSmiles(smiles) + if mol is None: + rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") + self.failed_counter += 1 + else: + try: + Chem.SanitizeMol(mol) + except Exception as e: + rank_zero_warn(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") + self.failed_counter += 1 + self.mol_object_buffer[smiles] = mol + return mol + + def _read_data(self, raw_data): + mol = self._smiles_to_mol(raw_data) + if mol is None: + return None + + x = torch.zeros((mol.GetNumAtoms(), 0)) + + edge_attr = torch.zeros((mol.GetNumBonds(), 0)) + + edge_index = torch.tensor( + [ + [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], + [bond.GetEndAtomIdx() for bond in mol.GetBonds()], + ] + ) + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + + def on_finish(self): + rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") + self.mol_object_buffer = {} + + def 
read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + return property.get_property_value(mol) + + +class GraphReader(dr.ChemDataReader): + """Reads each atom as one token (atom symbol + charge), reads bond order as edge attribute. + Creates nx Graph from SMILES.""" + + COLLATOR = GraphCollator + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dirname = os.path.dirname(__file__) + + @classmethod + def name(cls): + return "graph" + + def _read_data(self, raw_data) -> Optional[GeomData]: + # raw_data is a SMILES string + try: + mol = ps.read_smiles(raw_data) + except ValueError: + return None + assert isinstance(mol, nx.Graph) + d = {} + de = {} + for node in mol.nodes: + n = mol.nodes[node] + try: + m = n["element"] + charge = n["charge"] + if charge: + if charge > 0: + m += "+" + else: + m += "-" + charge *= -1 + if charge > 1: + m += str(charge) + m = f"[{m}]" + except KeyError: + m = "*" + d[node] = self._get_token_index(m) + for attr in list(mol.nodes[node].keys()): + del mol.nodes[node][attr] + for edge in mol.edges: + de[edge] = mol.edges[edge]["order"] + for attr in list(mol.edges[edge].keys()): + del mol.edges[edge][attr] + nx.set_node_attributes(mol, d, "x") + nx.set_edge_attributes(mol, de, "edge_attr") + data = from_networkx(mol) + return data + + def collate(self, list_of_tuples): + return self.collator(list_of_tuples) From 5d822295dbba3ab6ed48019c178744c939b35e19 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 1 May 2025 12:21:01 +0200 Subject: [PATCH 045/224] abstract base class for augmentor readers --- .../preprocessing/reader/augmented_reader.py | 77 +++++++++++-------- 1 file changed, 47 insertions(+), 30 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index ca5c7c8..a7f035e 100644 --- 
a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,4 +1,5 @@ -from typing import List, Optional +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple from chebai.preprocessing.reader import ChemDataReader from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn @@ -16,7 +17,7 @@ from chebai_graph.preprocessing.properties.constants import * -class GraphFGAugmentorReader(ChemDataReader): +class _AugmentorReader(ChemDataReader, ABC): COLLATOR = GraphCollator def __init__( @@ -28,6 +29,50 @@ def __init__( self.failed_counter = 0 self.mol_object_buffer = {} + @classmethod + @abstractmethod + def name(cls) -> str: + pass + + @abstractmethod + def _get_augmented_molecule(self, smile: str) -> Tuple[Dict, torch.Tensor]: + pass + + @abstractmethod + def _read_data(self, raw_data: str) -> List[int]: + pass + + def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") + self.failed_counter += 1 + else: + try: + Chem.SanitizeMol(mol) + except Exception as e: + rank_zero_warn(f"Rdkit failed at sanitizing {smiles}, Error {e}") + self.failed_counter += 1 + return mol + + def on_finish(self): + rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") + self.mol_object_buffer = {} + + def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + + augmented_mol, _ = self._get_augmented_molecule(smiles) + return property.get_property_value(mol) + + +class GraphFGAugmentorReader(_AugmentorReader): + @classmethod def name(cls): return "graph_fg_augmentor" @@ -54,19 +99,6 @@ def _get_augmented_molecule(self, smiles): 
return augmented_mol, edge_index - def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: - mol = Chem.MolFromSmiles(smiles) - if mol is None: - rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") - self.failed_counter += 1 - else: - try: - Chem.SanitizeMol(mol) - except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}, Error {e}") - self.failed_counter += 1 - return mol - def _augment_graph(self, mol: Chem.Mol): edge_index = torch.tensor( [ @@ -224,21 +256,6 @@ def _get_ring_size(self, atom): else: return 0 - def on_finish(self): - rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") - self.mol_object_buffer = {} - - def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: - mol = self._smiles_to_mol(smiles) - if mol is None: - return None - - if smiles in self.mol_object_buffer: - return property.get_property_value(self.mol_object_buffer[smiles]) - - augmented_mol, _ = self._get_augmented_molecule(smiles) - return property.get_property_value(mol) - class RuleBasedFGReader(ChemDataReader): From 4345b3d3484e34dd9f41e19113b2f9e3b0b16e21 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 1 May 2025 13:17:49 +0200 Subject: [PATCH 046/224] for 3d plot hover info to be id of the node --- .../utils/visualize_augmented_molecule.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index ad31308..866a284 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -1,10 +1,9 @@ import matplotlib import matplotlib.pyplot as plt import networkx as nx -import torch from jsonargparse import CLI -from rdkit import Chem -from rdkit.Chem import AllChem +from rdkit.Chem import AllChem, Mol +from torch import Tensor from 
chebai_graph.preprocessing.properties.constants import * from chebai_graph.preprocessing.reader import GraphFGAugmentorReader @@ -26,7 +25,7 @@ def _create_graph( - edge_index: torch.Tensor, augmented_graph_nodes: dict, augmented_graph_edges: dict + edge_index: Tensor, augmented_graph_nodes: dict, augmented_graph_edges: dict ) -> nx.Graph: """ Create a NetworkX graph from augmented molecular information. @@ -115,7 +114,7 @@ def _get_subgraph_by_node_type(G: nx.Graph, node_type: str) -> nx.Graph: return G.subgraph(selected_nodes).copy() -def _draw_hierarchy(G: nx.Graph, mol: Chem.Mol) -> None: +def _draw_hierarchy(G: nx.Graph, mol: Mol) -> None: """ Draw a hierarchical layout combining RDKit 2D coordinates for atoms and spring layout for FG/graph nodes. @@ -190,7 +189,7 @@ def _draw_simple(G: nx.Graph) -> None: plt.show() -def _draw_3d(G: nx.Graph, mol: Chem.Mol) -> None: +def _draw_3d(G: nx.Graph, mol: Mol) -> None: """ Visualize the graph in 3D using Plotly. @@ -261,9 +260,9 @@ def _draw_3d(G: nx.Graph, mol: Chem.Mol) -> None: edge_traces.append(trace) # Collect node attributes for visualization - pos_x, pos_y, pos_z, node_colors, node_names = zip( + pos_x, pos_y, pos_z, node_colors, node_names, node_ids = zip( *[ - (pos[n][0], pos[n][1], pos[n][2], attr["node_color"], attr["node_name"]) + (pos[n][0], pos[n][1], pos[n][2], attr["node_color"], attr["node_name"], n) for n, attr in G.nodes(data=True) ] ) @@ -276,6 +275,7 @@ def _draw_3d(G: nx.Graph, mol: Chem.Mol) -> None: marker=dict(size=8, color=node_colors, opacity=0.9), text=node_names, textposition="top center", + hovertext=node_ids, hoverinfo="text", ) @@ -294,10 +294,10 @@ def _draw_3d(G: nx.Graph, mol: Chem.Mol) -> None: def plot_augmented_graph( - edge_index: torch.Tensor, + edge_index: Tensor, augmented_graph_nodes: dict, augmented_graph_edges: dict, - mol: Chem.Mol, + mol: Mol, plot_type: str, ) -> None: """ From ad864c50d59fef539f524fc6e4426cd0e731c00b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 
1 May 2025 16:39:50 +0200 Subject: [PATCH 047/224] logic change to detect FG related to rings --- .../preprocessing/fg_detection/rule_based.py | 9 ++-- .../preprocessing/reader/augmented_reader.py | 43 ++++++++++--------- .../utils/visualize_augmented_molecule.py | 5 +++ 3 files changed, 32 insertions(+), 25 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/rule_based.py b/chebai_graph/preprocessing/fg_detection/rule_based.py index ab4fe30..186ff70 100644 --- a/chebai_graph/preprocessing/fg_detection/rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/rule_based.py @@ -1927,10 +1927,9 @@ def get_structure(mol): if __name__ == "__main__": from rdkit.Chem import MolFromSmiles as s2m - SMILES = ( - "CCOc1c(OC)cccc1[C@@H]1C(C(=O)OCCOC)=C(C)N=c2s/c(=C/c3cccc(OCC#N)c3)c(=O)n21" - ) - mol = s2m(SMILES) - # set_atom_map_num(mol) + smiles = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin, CHEBI:15365 - acetylsalicylic acid + mol = s2m(smiles) + set_atom_map_num(mol) + detect_functional_group(mol) get_structure(mol) print(m2s(mol)) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index a7f035e..f474afc 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple +import torch from chebai.preprocessing.reader import ChemDataReader from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from rdkit import Chem from torch_geometric.data import Data as GeomData -from wandb.integration.torch.wandb_torch import torch from chebai_graph.preprocessing.collate import GraphCollator from chebai_graph.preprocessing.fg_detection.rule_based import ( @@ -140,29 +140,17 @@ def _augment_graph(self, mol: Chem.Mol): fg_atom_edges[f"{num_of_nodes}_{atom}"] = {EDGE_LEVEL: ATOM_FG_EDGE} num_of_edges += 1 - fg_set = { - 
mol.GetAtomWithIdx(atom_idx).GetProp("FG") + ring_fg = { + mol.GetAtomWithIdx(atom_idx).GetProp("RING") for atom_idx in structure[fg]["atom"] - if mol.GetAtomWithIdx(atom_idx).GetProp("FG") + if mol.GetAtomWithIdx(atom_idx).GetProp("RING") } - if len(fg_set) > 1: - raise Exception("connected atoms should belong to only one fg") - - elif len(fg_set) == 0: - ring_sizes = set() - for atom_idx in structure[fg]["atom"]: - atom = mol.GetAtomWithIdx(atom_idx) - ring_size_prop = atom.GetProp("RING") - if not ring_size_prop: - raise Exception("All atoms should have ring size") - ring_sizes.add(int(ring_size_prop)) - atom.SetProp("FG", f"RING_{ring_size_prop}") - - # TODO: Incase error is raised check logic for fused rings - assert len(ring_sizes) == 1, "all atoms should have one ring size" - ring_size = list(ring_sizes)[0] + if len(ring_fg) > 1: + raise Exception("connected atom rings should have only one ring size") + elif len(ring_fg) == 1: + ring_size = next(iter(ring_fg)) fg_nodes[num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, "FG": f"RING_{ring_size}", @@ -170,11 +158,26 @@ def _augment_graph(self, mol: Chem.Mol): } else: + + fg_set = { + mol.GetAtomWithIdx(atom_idx).GetProp("FG") + for atom_idx in structure[fg]["atom"] + } + if "" in fg_set: + raise Exception( + "All connected atoms have a Functional Group assigned" + ) + elif len(fg_set) > 1: + raise Exception( + "All connected atoms should belong to only one Functional Group or should have" + ) + any_atom = None for atom_idx in structure[fg]["atom"]: atom = mol.GetAtomWithIdx(atom_idx) if atom.GetProp("FG"): any_atom = atom + assert any_atom is not None, "Need a FG" fg_nodes[num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 866a284..f29d109 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ 
b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -352,4 +352,9 @@ def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> No if __name__ == "__main__": # Example: python visualize_augmented_molecule.py plot --smiles="OC(=O)c1ccccc1O" --plot_type="h" + # Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365, acetylsalicylic acid + # Salicylic acid -> OC(=O)c1ccccc1O ; CHEBI:16914 + # 1-hydroxy-2-naphthoic acid -> OC(=O)c1ccc2ccccc2c1O ; CHEBI:36108 ; Fused Rings + # 3-nitrobenzoic acid -> OC(=O)C1=CC(=CC=C1)[N+]([O-])=O ; CHEBI:231494 ; Ring + Novel atom (Nitrogen) + # nile blue A -> [Cl-].CCN(CC)c1ccc2nc3c(cc(N)c4ccccc34)[o+]c2c1 ; CHEBI:52163 ; Fused rings + Novel atoms CLI(Main) From d0d35f71fed9c23af0abb7aee230410df1742d65 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 1 May 2025 19:10:55 +0200 Subject: [PATCH 048/224] restructure reader and add docstrings + typehints --- .../preprocessing/reader/augmented_reader.py | 490 ++++++++++++------ .../utils/visualize_augmented_molecule.py | 16 +- 2 files changed, 347 insertions(+), 159 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index f474afc..2d279b0 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -18,48 +18,109 @@ class _AugmentorReader(ChemDataReader, ABC): + """ + Abstract base class for augmentor readers that extend ChemDataReader. + Handles reading molecular data and augmenting molecules with functional group + information. + + Attributes: + failed_counter (int): Counter for failed SMILES parsing attempts. + mol_object_buffer (dict): Cache for storing augmented molecular objects. + """ + COLLATOR = GraphCollator - def __init__( - self, - *args, - **kwargs, - ): + def __init__(self, *args, **kwargs): + """ + Initializes the augmentor reader and sets up the failure counter and molecule cache. 
+ + Args: + *args: Additional arguments passed to the ChemDataReader. + **kwargs: Additional keyword arguments passed to the ChemDataReader. + """ super().__init__(*args, **kwargs) self.failed_counter = 0 self.mol_object_buffer = {} + self.num_nodes = 0 + self.num_of_edges = 0 @classmethod @abstractmethod def name(cls) -> str: + """ + Returns the name of the augmentor. + + Returns: + str: Name of the augmentor. + """ pass @abstractmethod - def _get_augmented_molecule(self, smile: str) -> Tuple[Dict, torch.Tensor]: + def _create_augmented_graph(self, smile: str) -> Tuple[Dict, torch.Tensor]: + """ + Augments a molecule represented by a SMILES string. + + Args: + smile (str): SMILES string representing the molecule. + + Returns: + Tuple[Dict, torch.Tensor]: Augmented molecule information and corresponding edge index. + """ pass @abstractmethod - def _read_data(self, raw_data: str) -> List[int]: + def _read_data(self, raw_data: str) -> GeomData: + """ + Reads raw data and returns a list of processed data. + + Args: + raw_data (str): Raw data input. + + Returns: + List[int]: Processed data as a list of integers. + """ pass - def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: + def _smiles_to_mol(self, smiles: str) -> Optional[Chem.Mol]: + """ + Converts a SMILES string to an RDKit molecule object. Sanitizes the molecule. + + Args: + smiles (str): SMILES string representing the molecule. + + Returns: + Optional[Chem.Mol]: RDKit molecule object if conversion is successful, else None. 
+ """ mol = Chem.MolFromSmiles(smiles) if mol is None: - rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") + rank_zero_warn(f"RDKit failed to parse {smiles} (returned None)") self.failed_counter += 1 else: try: Chem.SanitizeMol(mol) except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}, Error {e}") + rank_zero_warn(f"RDKit failed at sanitizing {smiles}, Error {e}") self.failed_counter += 1 return mol def on_finish(self): + """ + Finalizes the reading process and logs the number of failed SMILES. + """ rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: + """ + Reads a specific property from a molecule represented by a SMILES string. + + Args: + smiles (str): SMILES string representing the molecule. + property (MolecularProperty): Molecular property object for which the value needs to be extracted. + + Returns: + Optional[List]: Property values if molecule parsing is successful, else None. + """ mol = self._smiles_to_mol(smiles) if mol is None: return None @@ -67,204 +128,333 @@ def read_property(self, smiles: str, property: MolecularProperty) -> Optional[Li if smiles in self.mol_object_buffer: return property.get_property_value(self.mol_object_buffer[smiles]) - augmented_mol, _ = self._get_augmented_molecule(smiles) + augmented_mol, _ = self._create_augmented_graph(smiles) return property.get_property_value(mol) class GraphFGAugmentorReader(_AugmentorReader): + """ + A reader class that augments molecules with artificial functional group (FG) nodes and a graph-level node + to support graph-based molecular learning tasks. + + The FG nodes to connected to its related atoms and graph node is connected to all FG nodes. + """ @classmethod - def name(cls): - return "graph_fg_augmentor" + def name(cls) -> str: + """ + Returns the name identifier of the augmentor. 
- def _read_data(self, raw_data): - augmented_mol, edge_index = self._get_augmented_molecule(raw_data) + Returns: + str: Name identifier. + """ + return "graph_fg_augmentor" - x = torch.zeros((augmented_mol["nodes"]["num_nodes"], 0)) - edge_attr = torch.zeros((augmented_mol["nodes"]["num_edges"], 0)) + def _read_data(self, smiles: str) -> GeomData | None: + """ + Reads and augments molecular data from a SMILES string. - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + Args: + smiles (str): SMILES representation of the molecule. - def _get_augmented_molecule(self, smiles): + Returns: + GeomData: A PyTorch Geometric Data object with augmented nodes and edges. + """ mol = self._smiles_to_mol(smiles) if mol is None: return None - edge_index, augmented_graph_nodes, augmented_graph_edges = self._augment_graph( - mol + edge_index, augmented_molecule = self._create_augmented_graph(mol) + self.mol_object_buffer[smiles] = augmented_molecule + + num_nodes = augmented_molecule["nodes"]["num_nodes"] + num_edges = augmented_molecule["edges"]["num_edges"] + + # Empty features initialized; node and edge features can be added later + x = torch.zeros((num_nodes, 0)) + edge_attr = torch.zeros((num_edges, 0)) + + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + + def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: + """ + Generates an augmented graph from a SMILES string. + + Args: + mol (Chem.Mol): A molecule generated by RDKit. + + Returns: + Tuple[dict, torch.Tensor]: Augmented molecule information and edge index. + """ + edge_index, node_info, edge_info = self._augment_graph_structure(mol) + + augmented_molecule = {"nodes": node_info, "edges": edge_info} + + return edge_index, augmented_molecule + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> Tuple[torch.Tensor, dict, dict]: + """ + Constructs the full augmented graph structure from a molecule. + + Args: + mol (Chem.Mol): RDKit molecule object. 
+ + Returns: + Tuple[torch.Tensor, dict, dict]: Edge index, node metadata, and edge metadata. + """ + self.num_of_nodes = mol.GetNumAtoms() + self.num_of_edges = mol.GetNumBonds() + + self._annotate_atoms_and_bonds(mol) + atom_edge_index = self._generate_atom_level_edge_index(mol) + + # Create FG-level structure and edges + fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds = ( + self._construct_fg_to_atom_structure(mol) + ) + fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure( + structured_fg_map, bonds + ) + fg_graph_edge_index, graph_node, fg_to_graph_edges = ( + self._construct_fg_to_graph_node_structure(structured_fg_map) ) - augmented_mol = {"nodes": augmented_graph_nodes, "edges": augmented_graph_edges} - self.mol_object_buffer[smiles] = augmented_mol + # Merge all edge types + full_edge_index = torch.cat( + [ + atom_edge_index, + torch.tensor(fg_atom_edge_index, dtype=torch.long), + torch.tensor(fg_internal_edge_index, dtype=torch.long), + torch.tensor(fg_graph_edge_index, dtype=torch.long), + ], + dim=1, + ) + + node_info = { + "atom_nodes": mol, + "fg_nodes": fg_nodes, + "graph_node": graph_node, + "num_nodes": self.num_of_nodes, + } + edge_info = { + WITHIN_ATOMS_EDGE: mol, + ATOM_FG_EDGE: atom_fg_edges, + WITHIN_FG_EDGE: internal_fg_edges, + FG_GRAPHNODE_EDGE: fg_to_graph_edges, + "num_edges": self.num_of_edges, + } - return augmented_mol, edge_index + return full_edge_index, node_info, edge_info - def _augment_graph(self, mol: Chem.Mol): + @staticmethod + def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: + """ + Annotates each atom and bond with node and edge with certain properties. + + Args: + mol (Chem.Mol): RDKit molecule. 
+ """ + for atom in mol.GetAtoms(): + atom.SetProp(NODE_LEVEL, ATOM_NODE_LEVEL) + for bond in mol.GetBonds(): + bond.SetProp(EDGE_LEVEL, WITHIN_ATOMS_EDGE) + + @staticmethod + def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: + """ + Generates bidirectional atom-level edge index tensor. + + Args: + mol (Chem.Mol): RDKit molecule. + + Returns: + torch.Tensor: Bidirectional edge index tensor. + """ edge_index = torch.tensor( [ [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], [bond.GetEndAtomIdx() for bond in mol.GetBonds()], ] ) - within_atoms_edge_index = torch.cat([edge_index, edge_index[[1, 0], :]], dim=1) + return torch.cat([edge_index, edge_index[[1, 0], :]], dim=1) - num_of_nodes = mol.GetNumAtoms() - num_of_edges = mol.GetNumBonds() + def _construct_fg_to_atom_structure( + self, mol: Chem.Mol + ) -> Tuple[List[List[int]], dict, dict, dict, list]: + """ + Constructs edges between functional group (FG) nodes and atom nodes. - set_atom_map_num(mol) - detect_functional_group(mol) - - for atom in mol.GetAtoms(): - atom.SetProp(NODE_LEVEL, ATOM_NODE_LEVEL) + Args: + mol (Chem.Mol): RDKit molecule. - for edge in mol.GetBonds(): - edge.SetProp(EDGE_LEVEL, WITHIN_ATOMS_EDGE) + Returns: + Tuple[List[List[int]], dict, dict, dict, list]: + Edge index, FG node info, FG-atom edge attributes, + structured FG mapping, and bond list. + """ + # Rule-based algorithm to detect functional groups + set_atom_map_num(mol) + detect_functional_group(mol) structure, bonds = get_structure(mol) + assert structure is not None, "Failed to detect functional groups." 
+ + fg_atom_edge_index = [[], []] + fg_nodes, atom_fg_edges = {}, {} + structured_fg_map = ( + {} + ) # Contains augmented fg-nodes and connected atoms indices + + for idx, fg_key in enumerate(structure): + structured_fg_map[self.num_of_nodes] = {"atom": structure[fg_key]["atom"]} + + # Build edge index for fg to atom nodes connections + for atom_idx in structure[fg_key]["atom"]: + fg_atom_edge_index[0] += [self.num_of_nodes, atom_idx] + fg_atom_edge_index[1] += [atom_idx, self.num_of_nodes] + atom_fg_edges[f"{self.num_of_nodes}_{atom_idx}"] = { + EDGE_LEVEL: ATOM_FG_EDGE + } + self.num_of_edges += 1 - if not structure: - raise ValueError("") - - # Preprocess the molecular structure to match feature dictionary keys - fg_to_atoms_edge_index = [[], []] - fg_nodes, fg_atom_edges = {}, {} - new_structure = {} - for idx, fg in enumerate(structure): - # new_sm = preprocess_smiles(sm) # Preprocess SMILES to match the feature dictionary - new_structure[num_of_nodes] = { - "atom": structure[fg]["atom"] # Get atom list for fragment - } - for atom in structure[fg]["atom"]: - fg_to_atoms_edge_index[0].extend([num_of_nodes, atom]) - fg_to_atoms_edge_index[1].extend([atom, num_of_nodes]) - fg_atom_edges[f"{num_of_nodes}_{atom}"] = {EDGE_LEVEL: ATOM_FG_EDGE} - num_of_edges += 1 - + # Identify ring vs. functional group type ring_fg = { - mol.GetAtomWithIdx(atom_idx).GetProp("RING") - for atom_idx in structure[fg]["atom"] - if mol.GetAtomWithIdx(atom_idx).GetProp("RING") + mol.GetAtomWithIdx(i).GetProp("RING") + for i in structure[fg_key]["atom"] + if mol.GetAtomWithIdx(i).GetProp("RING") } if len(ring_fg) > 1: - raise Exception("connected atom rings should have only one ring size") + raise ValueError( + "A functional group must not span multiple ring sizes." 
+ ) - elif len(ring_fg) == 1: + if ( + len(ring_fg) == 1 + ): # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings ring_size = next(iter(ring_fg)) - fg_nodes[num_of_nodes] = { + fg_nodes[self.num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, + # E.g., Fused Ring has size "5-6", indicating size of each connected ring in fused ring "FG": f"RING_{ring_size}", "RING": ring_size, } - - else: - + else: # No connected has a ring size which indicates it is simple FG fg_set = { - mol.GetAtomWithIdx(atom_idx).GetProp("FG") - for atom_idx in structure[fg]["atom"] + mol.GetAtomWithIdx(i).GetProp("FG") + for i in structure[fg_key]["atom"] } - if "" in fg_set: - raise Exception( - "All connected atoms have a Functional Group assigned" - ) - elif len(fg_set) > 1: - raise Exception( - "All connected atoms should belong to only one Functional Group or should have" - ) + if "" in fg_set or len(fg_set) > 1: + raise ValueError("Invalid functional group assignment to atoms.") - any_atom = None - for atom_idx in structure[fg]["atom"]: + for atom_idx in structure[fg_key]["atom"]: atom = mol.GetAtomWithIdx(atom_idx) if atom.GetProp("FG"): - any_atom = atom - assert any_atom is not None, "Need a FG" - - fg_nodes[num_of_nodes] = { - NODE_LEVEL: FG_NODE_LEVEL, - "FG": any_atom.GetProp("FG"), - "RING": any_atom.GetProp("RING"), - } + fg_nodes[self.num_of_nodes] = { + NODE_LEVEL: FG_NODE_LEVEL, + "FG": atom.GetProp("FG"), + "RING": atom.GetProp("RING"), + } + break + else: + raise AssertionError( + "Expected at least one atom with a functional group." + ) - num_of_nodes += 1 + self.num_of_nodes += 1 - fg_edges = {} - within_fg_edge_index = [[], []] - # TODO: Can we optimize this ? 
- for bond in bonds: - start_idx, end_idx = bond[:2] - for key, value in new_structure.items(): - if start_idx in value["atom"]: - source_fg = key - if end_idx in value["atom"]: - target_fg = key - within_fg_edge_index[0].extend([source_fg, target_fg]) - within_fg_edge_index[1].extend([target_fg, source_fg]) - fg_edges[f"{source_fg}_{target_fg}"] = {EDGE_LEVEL: WITHIN_FG_EDGE} - num_of_edges += 1 - - graph_node = { - NODE_LEVEL: GRAPH_NODE_LEVEL, - "FG": "graph_fg", - "RING": "0", - } + return fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds - fg_graphNode_edges = {} - global_node_edge_index = [[], []] - for fg in new_structure.keys(): - global_node_edge_index[0].extend([num_of_nodes, fg]) - global_node_edge_index[1].extend([fg, num_of_nodes]) - fg_graphNode_edges[f"{num_of_nodes}_{fg}"] = {NODE_LEVEL: FG_GRAPHNODE_EDGE} - num_of_edges += 1 + def _construct_fg_level_structure( + self, structured_fg_map: dict, bonds: list + ) -> Tuple[List[List[int]], dict]: + """ + Constructs internal edges between functional group nodes based on bond connections. - all_edges = torch.cat( - [ - within_atoms_edge_index, - torch.tensor(fg_to_atoms_edge_index, dtype=torch.long), - torch.tensor(within_fg_edge_index, dtype=torch.long), - torch.tensor(global_node_edge_index, dtype=torch.long), - ], - dim=1, - ) + Args: + structured_fg_map (dict): Mapping from FG ID to atom indices. + bonds (list): List of bond tuples (source, target, ...). - augmented_graph_nodes = { - "atom_nodes": mol, - "fg_nodes": fg_nodes, - "graph_node": graph_node, - "num_nodes": num_of_nodes, - } - augmented_graph_edges = { - WITHIN_ATOMS_EDGE: mol, - ATOM_FG_EDGE: fg_atom_edges, - WITHIN_FG_EDGE: fg_edges, - FG_GRAPHNODE_EDGE: fg_graphNode_edges, - "num_edges": num_of_edges, - } + Returns: + Tuple[List[List[int]], dict]: Edge index and edge attribute dictionary. 
+ """ + internal_fg_edges = {} + internal_edge_index = [[], []] - return all_edges, augmented_graph_nodes, augmented_graph_edges + for bond in bonds: + source_atom, target_atom = bond[:2] + source_fg, target_fg = None, None + + for fg_id, data in structured_fg_map.items(): + if source_atom in data["atom"]: + source_fg = fg_id + if target_atom in data["atom"]: + target_fg = fg_id + + assert ( + source_fg is not None and target_fg is not None + ), "Each bond should have a fg node on both end" + + internal_edge_index[0] += [source_fg, target_fg] + internal_edge_index[1] += [target_fg, source_fg] + internal_fg_edges[f"{source_fg}_{target_fg}"] = {EDGE_LEVEL: WITHIN_FG_EDGE} + self.num_of_edges += 1 + + return internal_edge_index, internal_fg_edges + + def _construct_fg_to_graph_node_structure( + self, structured_fg_map: dict + ) -> Tuple[List[List[int]], dict, dict]: + """ + Constructs edges between functional group nodes and a global graph-level node. + + Args: + structured_fg_map (dict): Mapping from FG ID to atom indices. + + Returns: + Tuple[List[List[int]], dict, dict]: Edge index, graph-level node, edge attributes. 
+ """ + graph_node = {NODE_LEVEL: GRAPH_NODE_LEVEL, "FG": "graph_fg", "RING": "0"} + + fg_graph_edges = {} + graph_edge_index = [[], []] + + for fg_id in structured_fg_map: + graph_edge_index[0] += [self.num_of_nodes, fg_id] + graph_edge_index[1] += [fg_id, self.num_of_nodes] + fg_graph_edges[f"{self.num_of_nodes}_{fg_id}"] = { + EDGE_LEVEL: FG_GRAPHNODE_EDGE + } + self.num_of_edges += 1 - def _get_fg_index(self, atom): - fg_group = atom.GetProp("FG") - if fg_group: - fg_index = self._get_token_index(fg_group) - return fg_index - else: - raise Exception("") - - def _get_ring_size(self, atom): - ring_size_str = atom.GetProp("RING") - if ring_size_str: - ring_sizes = list(map(int, ring_size_str.split("-"))) - # TODO: Decide ring size for atoms belongs to fused rings, rn only max ring size taken - return max(ring_sizes) - else: - return 0 + return graph_edge_index, graph_node, fg_graph_edges class RuleBasedFGReader(ChemDataReader): + """ + A reader which give numeric value for given functional group. + """ @classmethod def name(cls) -> str: + """ + Returns the name of the rule-based functional group reader. + + Returns: + str: The name of the reader. + """ return "rule_based_fg" - def _read_data(self, fg: str) -> int | None: + def _read_data(self, fg: str) -> Optional[int]: + """ + Reads and returns the token index for a given functional group. + + Args: + fg (str): The functional group to look up. + + Returns: + Optional[int]: The index of the functional group, or None if not found. 
+ """ return self._get_token_index(fg) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index f29d109..8224d17 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -295,8 +295,7 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: def plot_augmented_graph( edge_index: Tensor, - augmented_graph_nodes: dict, - augmented_graph_edges: dict, + augmented_molecule: dict, mol: Mol, plot_type: str, ) -> None: @@ -305,12 +304,13 @@ def plot_augmented_graph( Args: edge_index (torch.Tensor): Edge indices tensor (2, num_edges). - augmented_graph_nodes (dict): Node metadata. - augmented_graph_edges (dict): Edge metadata. + augmented_molecule (dict): Augmented Molecule. mol (Chem.Mol): RDKit molecule object. plot_type (str): One of ["simple", "h", "3d"]. """ - G = _create_graph(edge_index, augmented_graph_nodes, augmented_graph_edges) + G = _create_graph( + edge_index, augmented_molecule["nodes"], augmented_molecule["edges"] + ) if plot_type == "h": _draw_hierarchy(G, mol) @@ -342,12 +342,10 @@ def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> No - 3d: Hierarchical 3D-graph """ mol = self._fg_reader._smiles_to_mol(smiles) # noqa - edge_index, augmented_nodes, augmented_edges = self._fg_reader._augment_graph( + edge_index, augmented_molecule = self._fg_reader._create_augmented_graph( mol ) # noqa - plot_augmented_graph( - edge_index, augmented_nodes, augmented_edges, mol, plot_type - ) + plot_augmented_graph(edge_index, augmented_molecule, mol, plot_type) if __name__ == "__main__": From 21160c3765c7896abb349553ee148e0cd683f051 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 1 May 2025 19:14:25 +0200 Subject: [PATCH 049/224] rename to `fg_aware_rule_based` algorithm --- .../fg_detection/{rule_based.py => fg_aware_rule_based.py} | 0 
chebai_graph/preprocessing/reader/augmented_reader.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename chebai_graph/preprocessing/fg_detection/{rule_based.py => fg_aware_rule_based.py} (100%) diff --git a/chebai_graph/preprocessing/fg_detection/rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py similarity index 100% rename from chebai_graph/preprocessing/fg_detection/rule_based.py rename to chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 2d279b0..797c1d3 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -8,7 +8,7 @@ from torch_geometric.data import Data as GeomData from chebai_graph.preprocessing.collate import GraphCollator -from chebai_graph.preprocessing.fg_detection.rule_based import ( +from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import ( detect_functional_group, get_structure, set_atom_map_num, From 9150bb6a91c636b76b8749f4b8c6f8f6beb5a42c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:18:55 +0200 Subject: [PATCH 050/224] Create .gitignore --- .gitignore | 170 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9676c5b --- /dev/null +++ b/.gitignore @@ -0,0 +1,170 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other 
infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# configs/ # commented as new configs can be added as a part of a feature + +/.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt +.isort.cfg From 06a71a66017708bd6ed9b21b9f13a64b2785e8c0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:23:11 +0200 Subject: [PATCH 051/224] update precommit + github action --- .github/workflows/black.yml | 10 ++++++++++ .pre-commit-config.yaml | 28 ++++++++++++++++++++++------ 2 files changed, 32 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/black.yml diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..b04fb15 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: psf/black@stable diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 866c153..108b91d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,25 @@ repos: -#- repo: 
https://github.com/PyCQA/isort -# rev: "5.12.0" -# hooks: -# - id: isort - repo: https://github.com/psf/black - rev: "22.10.0" + rev: "24.2.0" hooks: - - id: black \ No newline at end of file + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace From ff1adc9acb877c56371f2a5c3b2d74269ec0da59 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:27:13 +0200 Subject: [PATCH 052/224] pre-commit format files --- README.md | 6 ++-- chebai_graph/models/gin_net.py | 9 +++--- chebai_graph/preprocessing/collate.py | 4 +-- chebai_graph/preprocessing/datasets/chebi.py | 32 ++++++++++--------- .../preprocessing/datasets/pubchem.py | 3 +- chebai_graph/preprocessing/properties.py | 6 ++-- .../preprocessing/property_encoder.py | 3 +- chebai_graph/preprocessing/reader.py | 21 ++++++------ .../preprocessing/transform_unlabeled.py | 1 + configs/data/chebi50_graph.yml | 2 +- configs/data/pubchem_graph.yml | 2 +- configs/loss/mask_pretraining.yml | 2 +- configs/model/gnn.yml | 2 +- configs/model/gnn_attention.yml | 2 +- configs/model/gnn_gine.yml | 2 +- configs/model/gnn_res_gated.yml | 2 +- configs/model/gnn_resgated_pretrain.yml | 2 +- pyproject.toml | 2 +- 18 files changed, 54 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 6af4630..c8ce94b 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,13 @@ ## Installation -Some requirements may not be installed successfully automatically. +Some requirements may not be installed successfully automatically. 
To install the `torch-` libraries, use `pip install torch-${lib} -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html` where `${lib}` is either `scatter`, `geometric`, `sparse` or `cluster`, and -`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. +`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g. [torch-geometric docs](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)) @@ -31,7 +31,7 @@ We recommend the following setup: If you run the command from the `python-chebai` directory, you can use the same data for both chebai- and chebai-graph-models (e.g., Transformers and GNNs). Then you have to use `{path-to-chebai} -> .` and `{path-to-chebai-graph} -> ../python-chebai-graph`. - + Pretraining on a atom / bond masking task with PubChem data (feature-branch): ``` python3 -m chebai fit --model={path-to-chebai-graph}/configs/model/gnn_resgated_pretrain.yml --data={path-to-chebai-graph}/configs/data/pubchem_graph.yml --trainer={path-to-chebai}/configs/training/pretraining_trainer.yml diff --git a/chebai_graph/models/gin_net.py b/chebai_graph/models/gin_net.py index 75c2c45..6fed4c6 100644 --- a/chebai_graph/models/gin_net.py +++ b/chebai_graph/models/gin_net.py @@ -1,10 +1,11 @@ +import typing + +import torch +import torch.nn.functional as F +import torch_geometric from torch_scatter import scatter_add from chebai_graph.models.graph import GraphBaseNet -import torch_geometric -import torch.nn.functional as F -import torch -import typing class AggregateMLP(torch.nn.Module): diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py index 2c5f696..4be36cf 100644 --- a/chebai_graph/preprocessing/collate.py +++ b/chebai_graph/preprocessing/collate.py @@ -1,11 +1,11 @@ from typing import Dict import torch +from chebai.preprocessing.collate import RaggedCollator from torch_geometric.data import Data as GeomData from torch_geometric.data.collate import collate 
as graph_collate -from chebai_graph.preprocessing.structures import XYGraphData -from chebai.preprocessing.collate import RaggedCollator +from chebai_graph.preprocessing.structures import XYGraphData class GraphCollator(RaggedCollator): diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 6ee8bc5..843ba35 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,26 +1,26 @@ -from typing import Optional, List, Callable +import importlib +import os +from typing import Callable, List, Optional +import pandas as pd +import torch +import tqdm +from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, ChEBIOverXPartial, ) -from chebai.preprocessing.datasets.base import XYBaseDataModule from lightning_utilities.core.rank_zero import rank_zero_info +from torch_geometric.data.data import Data as GeomData -from chebai_graph.preprocessing.reader import GraphReader, GraphPropertyReader +import chebai_graph.preprocessing.properties as graph_properties from chebai_graph.preprocessing.properties import ( AtomProperty, BondProperty, MolecularProperty, ) -import pandas as pd -from torch_geometric.data.data import Data as GeomData -import torch -import chebai_graph.preprocessing.properties as graph_properties -import importlib -import os -import tqdm +from chebai_graph.preprocessing.reader import GraphPropertyReader, GraphReader class ChEBI50GraphData(ChEBIOver50): @@ -84,9 +84,11 @@ def _setup_properties(self): for file in file_names: # processed_dir_main only exists for ChEBI datasets path = os.path.join( - self.processed_dir_main - if hasattr(self, "processed_dir_main") - else self.raw_dir, + ( + self.processed_dir_main + if hasattr(self, "processed_dir_main") + else self.raw_dir + ), file, ) raw_data += list(self._load_dict(path)) @@ -94,8 +96,8 @@ def _setup_properties(self): 
features = [row["features"] for row in raw_data] # use vectorized version of encode function, apply only if value is present - enc_if_not_none = ( - lambda encode, value: [encode(atom_v) for atom_v in value] + enc_if_not_none = lambda encode, value: ( + [encode(atom_v) for atom_v in value] if value is not None and len(value) > 0 else None ) diff --git a/chebai_graph/preprocessing/datasets/pubchem.py b/chebai_graph/preprocessing/datasets/pubchem.py index 210b7ab..6f5d118 100644 --- a/chebai_graph/preprocessing/datasets/pubchem.py +++ b/chebai_graph/preprocessing/datasets/pubchem.py @@ -1,6 +1,7 @@ -from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn from chebai.preprocessing.datasets.pubchem import PubchemChem +from chebai_graph.preprocessing.datasets.chebi import GraphPropertiesMixIn + class PubChemGraphProperties(GraphPropertiesMixIn, PubchemChem): pass diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 95f85ab..9b927ed 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -6,11 +6,11 @@ from descriptastorus.descriptors import rdNormalizedDescriptors from chebai_graph.preprocessing.property_encoder import ( - PropertyEncoder, - IndexEncoder, - OneHotEncoder, AsIsEncoder, BoolEncoder, + IndexEncoder, + OneHotEncoder, + PropertyEncoder, ) diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 497025c..ebfbe0c 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -1,8 +1,9 @@ import abc import os -import torch from typing import Optional +import torch + class PropertyEncoder(abc.ABC): def __init__(self, property, **kwargs): diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index b814d53..448f402 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py 
@@ -1,19 +1,18 @@ import importlib +import os +from typing import List, Mapping, Optional, Tuple -from torch_geometric.utils import from_networkx -from typing import Tuple, Mapping, Optional, List - -import importlib +import chebai.preprocessing.reader as dr import networkx as nx -import os -import torch -import rdkit.Chem as Chem import pysmiles as ps -import chebai.preprocessing.reader as dr -from chebai_graph.preprocessing.collate import GraphCollator -import chebai_graph.preprocessing.properties as properties +import rdkit.Chem as Chem +import torch +from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from lightning_utilities.core.rank_zero import rank_zero_warn, rank_zero_info +from torch_geometric.utils import from_networkx + +import chebai_graph.preprocessing.properties as properties +from chebai_graph.preprocessing.collate import GraphCollator class GraphPropertyReader(dr.ChemDataReader): diff --git a/chebai_graph/preprocessing/transform_unlabeled.py b/chebai_graph/preprocessing/transform_unlabeled.py index 3920659..0cc4b35 100644 --- a/chebai_graph/preprocessing/transform_unlabeled.py +++ b/chebai_graph/preprocessing/transform_unlabeled.py @@ -1,4 +1,5 @@ import random + import torch diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml index 14cc489..19c8753 100644 --- a/configs/data/chebi50_graph.yml +++ b/configs/data/chebi50_graph.yml @@ -1 +1 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData \ No newline at end of file +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml index af04491..c21f188 100644 --- a/configs/data/pubchem_graph.yml +++ b/configs/data/pubchem_graph.yml @@ -16,4 +16,4 @@ init_args: - chebai_graph.preprocessing.properties.BondInRing - chebai_graph.preprocessing.properties.BondAromaticity #- 
chebai_graph.preprocessing.properties.MoleculeNumRings - - chebai_graph.preprocessing.properties.RDKit2DNormalized \ No newline at end of file + - chebai_graph.preprocessing.properties.RDKit2DNormalized diff --git a/configs/loss/mask_pretraining.yml b/configs/loss/mask_pretraining.yml index c677559..6d2a560 100644 --- a/configs/loss/mask_pretraining.yml +++ b/configs/loss/mask_pretraining.yml @@ -1 +1 @@ -class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss \ No newline at end of file +class_path: chebai_graph.loss.pretraining.MaskPretrainingLoss diff --git a/configs/model/gnn.yml b/configs/model/gnn.yml index b0b119d..f85fa76 100644 --- a/configs/model/gnn.yml +++ b/configs/model/gnn.yml @@ -7,4 +7,4 @@ init_args: hidden_length: 512 dropout_rate: 0.1 n_conv_layers: 3 - n_linear_layers: 3 \ No newline at end of file + n_linear_layers: 3 diff --git a/configs/model/gnn_attention.yml b/configs/model/gnn_attention.yml index b1c553b..0c11ced 100644 --- a/configs/model/gnn_attention.yml +++ b/configs/model/gnn_attention.yml @@ -8,4 +8,4 @@ init_args: dropout_rate: 0.1 n_conv_layers: 5 n_linear_layers: 3 - n_heads: 5 \ No newline at end of file + n_heads: 5 diff --git a/configs/model/gnn_gine.yml b/configs/model/gnn_gine.yml index 0d0ed20..c84ea61 100644 --- a/configs/model/gnn_gine.yml +++ b/configs/model/gnn_gine.yml @@ -8,4 +8,4 @@ init_args: n_conv_layers: 5 n_linear_layers: 3 n_atom_properties: 125 - n_bond_properties: 5 \ No newline at end of file + n_bond_properties: 5 diff --git a/configs/model/gnn_res_gated.yml b/configs/model/gnn_res_gated.yml index d9ddc05..27d1e78 100644 --- a/configs/model/gnn_res_gated.yml +++ b/configs/model/gnn_res_gated.yml @@ -10,4 +10,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 158 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/configs/model/gnn_resgated_pretrain.yml b/configs/model/gnn_resgated_pretrain.yml index c26db76..fad8c27 100644 --- 
a/configs/model/gnn_resgated_pretrain.yml +++ b/configs/model/gnn_resgated_pretrain.yml @@ -13,4 +13,4 @@ init_args: n_linear_layers: 3 n_atom_properties: 151 n_bond_properties: 7 - n_molecule_properties: 200 \ No newline at end of file + n_molecule_properties: 200 diff --git a/pyproject.toml b/pyproject.toml index 64c572c..4aea1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,4 +25,4 @@ build-backend = "flit_core.buildapi" requires = ["flit_core >=3.2,<4"] [project.entry-points.'chebai.plugins'] -models = 'chebai_graph.models' \ No newline at end of file +models = 'chebai_graph.models' From d7f30d38c51235f4d21731360d11ff9ff0724344 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 16:31:50 +0200 Subject: [PATCH 053/224] change graph from directed to UNDIRECTED --- chebai_graph/preprocessing/reader.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 448f402..ced9f31 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -1,6 +1,5 @@ -import importlib import os -from typing import List, Mapping, Optional, Tuple +from typing import List, Optional import chebai.preprocessing.reader as dr import networkx as nx @@ -9,7 +8,7 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from torch_geometric.utils import from_networkx +from torch_geometric.utils import from_networkx, to_undirected import chebai_graph.preprocessing.properties as properties from chebai_graph.preprocessing.collate import GraphCollator @@ -44,7 +43,7 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: try: Chem.SanitizeMol(mol) except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}") + rank_zero_warn(f"Rdkit failed at sanitizing {smiles} \n Error: {e}") self.failed_counter += 1 
self.mol_object_buffer[smiles] = mol return mol @@ -64,7 +63,7 @@ def _read_data(self, raw_data): [bond.GetEndAtomIdx() for bond in mol.GetBonds()], ] ) - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + return GeomData(x=x, edge_index=to_undirected(edge_index), edge_attr=edge_attr) def on_finish(self): rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") From d3a6fe131364616e4ae7d87cafb5174731a1c7b2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 23:34:50 +0200 Subject: [PATCH 054/224] properties error fix --- chebai_graph/preprocessing/properties/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 45a5b65..fcee6d3 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -4,6 +4,8 @@ # isort: off from .properties import ( MolecularProperty, + AtomProperty, + BondProperty, AtomType, NumAtomBonds, AtomCharge, @@ -29,6 +31,8 @@ __all__ = [ "MolecularProperty", + "AtomProperty", + "BondProperty", "AtomType", "NumAtomBonds", "AtomCharge", From 05254708be7c3241104453929ad4013f2bed0907 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 8 May 2025 16:22:07 +0200 Subject: [PATCH 055/224] fix props error --- .../preprocessing/properties/augmented_properties.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 3b7833c..5b03d2a 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -97,7 +97,7 @@ def get_property_value(self, augmented_mol: Dict): prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] - graph_node = 
atom_molecule[self.MAIN_KEY]["graph_node"] + graph_node = augmented_mol[self.MAIN_KEY]["graph_node"] if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict): raise TypeError( f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict ' @@ -107,8 +107,8 @@ def get_property_value(self, augmented_mol: Dict): # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes]) - prop_list.extend([self.get_atom_value(atom) for atom in graph_node]) + prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()]) + prop_list.append(self.get_atom_value(graph_node)) return prop_list From 33bf005f0fe10cf9d067d7e159336ac32494bc7d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 May 2025 00:03:30 +0200 Subject: [PATCH 056/224] make augment prop a subclass of atom, bond prop --- .../preprocessing/properties/augmented_properties.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 5b03d2a..134191a 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -3,12 +3,13 @@ from rdkit import Chem -from chebai_graph.preprocessing.properties import MolecularProperty -from chebai_graph.preprocessing.properties.constants import * from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder +from .constants import * +from .properties import AtomProperty, BondProperty -class AugmentedBondProperty(MolecularProperty, ABC): + +class AugmentedBondProperty(BondProperty, ABC): MAIN_KEY = "edges" def 
get_property_value(self, augmented_mol: Dict): @@ -73,7 +74,7 @@ def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") -class AugmentedAtomProperty(MolecularProperty, ABC): +class AugmentedAtomProperty(AtomProperty, ABC): MAIN_KEY = "nodes" def get_property_value(self, augmented_mol: Dict): From d07e45c2e95fb39f1e1e68fa9738a1d2119e2ce7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 May 2025 00:04:30 +0200 Subject: [PATCH 057/224] remove rule based reader --- .../properties/augmented_properties.py | 9 ++---- .../preprocessing/reader/augmented_reader.py | 32 ++----------------- 2 files changed, 4 insertions(+), 37 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 134191a..fa27c17 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -147,19 +147,14 @@ class AtomFunctionalGroup(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) - # To avoid circular imports - from chebai_graph.preprocessing.reader import RuleBasedFGReader - - self.fg_reader = RuleBasedFGReader() - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): return self._check_modify_atom_prop_value(atom, "FG") def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): if isinstance(atom, Chem.rdchem.Atom): - return self.fg_reader._read_data(atom.GetProp(prop)) # noqa + return atom.GetProp(prop) elif isinstance(atom, dict): - return self.fg_reader._read_data(atom[prop]) # noqa + return atom[prop] else: raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.") diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 
797c1d3..8d39adf 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple import torch -from chebai.preprocessing.reader import ChemDataReader +from chebai.preprocessing.reader import DataReader from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from rdkit import Chem from torch_geometric.data import Data as GeomData @@ -17,7 +17,7 @@ from chebai_graph.preprocessing.properties.constants import * -class _AugmentorReader(ChemDataReader, ABC): +class _AugmentorReader(DataReader, ABC): """ Abstract base class for augmentor readers that extend ChemDataReader. Handles reading molecular data and augmenting molecules with functional group @@ -430,31 +430,3 @@ def _construct_fg_to_graph_node_structure( self.num_of_edges += 1 return graph_edge_index, graph_node, fg_graph_edges - - -class RuleBasedFGReader(ChemDataReader): - """ - A reader which give numeric value for given functional group. - """ - - @classmethod - def name(cls) -> str: - """ - Returns the name of the rule-based functional group reader. - - Returns: - str: The name of the reader. - """ - return "rule_based_fg" - - def _read_data(self, fg: str) -> Optional[int]: - """ - Reads and returns the token index for a given functional group. - - Args: - fg (str): The functional group to look up. - - Returns: - Optional[int]: The index of the functional group, or None if not found. 
- """ - return self._get_token_index(fg) From 2b8206a10e3e2a8dcbc8911f0cceba0d72dd2b8f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 May 2025 00:06:41 +0200 Subject: [PATCH 058/224] add vscode to ignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9676c5b..3f0c0ab 100644 --- a/.gitignore +++ b/.gitignore @@ -168,3 +168,4 @@ cython_debug/ /results_buffer electra_pretrained.ckpt .isort.cfg +/.vscode From 268141601bec42564abe44b58a04b458525ddf8b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 May 2025 14:25:29 +0200 Subject: [PATCH 059/224] add data class for augmentation --- chebai_graph/preprocessing/datasets/chebi.py | 10 +++++++++- configs/data/chebi50_augmented_gnn.yml | 5 +---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 843ba35..dafcc2b 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -20,7 +20,11 @@ BondProperty, MolecularProperty, ) -from chebai_graph.preprocessing.reader import GraphPropertyReader, GraphReader +from chebai_graph.preprocessing.reader import ( + GraphFGAugmentorReader, + GraphPropertyReader, + GraphReader, +) class ChEBI50GraphData(ChEBIOver50): @@ -226,3 +230,7 @@ class ChEBI100GraphProperties(GraphPropertiesMixIn, ChEBIOver100): class ChEBI50GraphPropertiesPartial(ChEBI50GraphProperties, ChEBIOverXPartial): pass + + +class ChEBI50GraphFGAugmentorReader(ChEBI50GraphProperties): + READER = GraphFGAugmentorReader diff --git a/configs/data/chebi50_augmented_gnn.yml b/configs/data/chebi50_augmented_gnn.yml index c748ac2..5cb9c38 100644 --- a/configs/data/chebi50_augmented_gnn.yml +++ b/configs/data/chebi50_augmented_gnn.yml @@ -1,7 +1,4 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphProperties +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphFGAugmentorReader 
init_args: properties: - - chebai_graph.preprocessing.properties.AtomRingSize - - chebai_graph.preprocessing.properties.AtomNodeLevel - chebai_graph.preprocessing.properties.AtomFunctionalGroup - - chebai_graph.preprocessing.properties.BondLevel From 17115ffce0708aa251d3fc36b11bcde141bc46b1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 May 2025 14:26:57 +0200 Subject: [PATCH 060/224] skip if reader return None, using walrus operator in python >=3.8 --- chebai_graph/preprocessing/datasets/chebi.py | 3 ++- pyproject.toml | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index dafcc2b..368664f 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -111,8 +111,9 @@ def _setup_properties(self): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode property_values = [ - self.reader.read_property(feat, property) + val for feat in tqdm.tqdm(features) + if (val := self.reader.read_property(feat, property)) is not None ] property.encoder.on_start(property_values=property_values) encoded_values = [ diff --git a/pyproject.toml b/pyproject.toml index 4aea1ae..01267a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,8 @@ dependencies = [ "torch-cluster", "descriptastorus" ] +requires-python = ">=3.8" + [project.optional-dependencies] dev = [ From ce867cb126a56394d3ac507dd6a58371d9d1e520 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 May 2025 14:28:24 +0200 Subject: [PATCH 061/224] molecular prop also accepts dict in case of augmentation --- chebai_graph/preprocessing/properties/properties.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py index 06693fe..5d1e2fc 100644 --- 
a/chebai_graph/preprocessing/properties/properties.py +++ b/chebai_graph/preprocessing/properties/properties.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Dict, Optional import numpy as np import rdkit.Chem as Chem @@ -32,7 +32,7 @@ def on_finish(self): def __str__(self): return self.name - def get_property_value(self, mol: Chem.rdchem.Mol): + def get_property_value(self, mol: Chem.rdchem.Mol | Dict): raise NotImplementedError From dfe51c9454d733c09451e6d27017ed662bfd6aa4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 May 2025 15:20:49 +0200 Subject: [PATCH 062/224] update augmented reader for wildcard smiles --- chebai_graph/preprocessing/reader/__init__.py | 3 +- .../preprocessing/reader/augmented_reader.py | 82 +++++++++++++------ 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py index 737ee92..09946f0 100644 --- a/chebai_graph/preprocessing/reader/__init__.py +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -1,9 +1,8 @@ -from .augmented_reader import GraphFGAugmentorReader, RuleBasedFGReader +from .augmented_reader import GraphFGAugmentorReader from .reader import GraphPropertyReader, GraphReader __all__ = [ "GraphReader", "GraphPropertyReader", "GraphFGAugmentorReader", - "RuleBasedFGReader", ] diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 8d39adf..51b0a3b 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -22,10 +22,6 @@ class _AugmentorReader(DataReader, ABC): Abstract base class for augmentor readers that extend ChemDataReader. Handles reading molecular data and augmenting molecules with functional group information. - - Attributes: - failed_counter (int): Counter for failed SMILES parsing attempts. 
- mol_object_buffer (dict): Cache for storing augmented molecular objects. """ COLLATOR = GraphCollator @@ -39,7 +35,12 @@ def __init__(self, *args, **kwargs): **kwargs: Additional keyword arguments passed to the ChemDataReader. """ super().__init__(*args, **kwargs) - self.failed_counter = 0 + self.f_cnt_for_smiles = ( + 0 # Record number of failures when constructing molecule from smiles + ) + self.f_cnt_for_aug_graph = ( + 0 # Record number of failure during augmented graph construction + ) self.mol_object_buffer = {} self.num_nodes = 0 self.num_of_edges = 0 @@ -56,15 +57,15 @@ def name(cls) -> str: pass @abstractmethod - def _create_augmented_graph(self, smile: str) -> Tuple[Dict, torch.Tensor]: + def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, Dict]: """ Augments a molecule represented by a SMILES string. Args: - smile (str): SMILES string representing the molecule. + mol (Chem.Mol): RDKIT molecule. Returns: - Tuple[Dict, torch.Tensor]: Augmented molecule information and corresponding edge index. + Tuple[torch.Tensor, Dict]: Graph edge index and augmented molecule information """ pass @@ -77,11 +78,11 @@ def _read_data(self, raw_data: str) -> GeomData: raw_data (str): Raw data input. Returns: - List[int]: Processed data as a list of integers. + GeomData: `torch_geometric.data.Data` object. """ pass - def _smiles_to_mol(self, smiles: str) -> Optional[Chem.Mol]: + def _smiles_to_mol(self, smiles: str) -> Chem.Mol: """ Converts a SMILES string to an RDKit molecule object. Sanitizes the molecule. @@ -89,25 +90,28 @@ def _smiles_to_mol(self, smiles: str) -> Optional[Chem.Mol]: smiles (str): SMILES string representing the molecule. Returns: - Optional[Chem.Mol]: RDKit molecule object if conversion is successful, else None. + Chem.Mol: RDKit molecule object. 
""" mol = Chem.MolFromSmiles(smiles) if mol is None: rank_zero_warn(f"RDKit failed to parse {smiles} (returned None)") - self.failed_counter += 1 + self.f_cnt_for_smiles += 1 else: try: Chem.SanitizeMol(mol) except Exception as e: rank_zero_warn(f"RDKit failed at sanitizing {smiles}, Error {e}") - self.failed_counter += 1 + self.f_cnt_for_smiles += 1 return mol - def on_finish(self): + def on_finish(self) -> None: """ - Finalizes the reading process and logs the number of failed SMILES. + Finalizes the reading process and logs the number of failed SMILES and failed augmentation. """ - rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") + rank_zero_info(f"Failed to read {self.f_cnt_for_smiles} SMILES in total") + rank_zero_info( + f"Failed to construct augmented graph for {self.f_cnt_for_aug_graph} number of SMILES" + ) self.mol_object_buffer = {} def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: @@ -121,15 +125,15 @@ def read_property(self, smiles: str, property: MolecularProperty) -> Optional[Li Returns: Optional[List]: Property values if molecule parsing is successful, else None. 
""" + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + mol = self._smiles_to_mol(smiles) if mol is None: return None - if smiles in self.mol_object_buffer: - return property.get_property_value(self.mol_object_buffer[smiles]) - - augmented_mol, _ = self._create_augmented_graph(smiles) - return property.get_property_value(mol) + _, augmented_mol = self._create_augmented_graph(mol) + return property.get_property_value(augmented_mol) class GraphFGAugmentorReader(_AugmentorReader): @@ -164,7 +168,14 @@ def _read_data(self, smiles: str) -> GeomData | None: if mol is None: return None - edge_index, augmented_molecule = self._create_augmented_graph(mol) + returned_result = self._create_augmented_graph(mol) + if returned_result is None: + rank_zero_info(f"Failed to construct augmented graph for smiles {smiles}") + self.f_cnt_for_aug_graph += 1 + return None + + edge_index, augmented_molecule = returned_result + self.mol_object_buffer[smiles] = augmented_molecule num_nodes = augmented_molecule["nodes"]["num_nodes"] @@ -176,7 +187,9 @@ def _read_data(self, smiles: str) -> GeomData | None: return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) - def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: + def _create_augmented_graph( + self, mol: Chem.Mol + ) -> Optional[Tuple[torch.Tensor, Dict]]: """ Generates an augmented graph from a SMILES string. @@ -186,7 +199,11 @@ def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: Returns: Tuple[dict, torch.Tensor]: Augmented molecule information and edge index. 
""" - edge_index, node_info, edge_info = self._augment_graph_structure(mol) + returned_result = self._augment_graph_structure(mol) + if returned_result is None: + return None + + edge_index, node_info, edge_info = returned_result augmented_molecule = {"nodes": node_info, "edges": edge_info} @@ -194,7 +211,7 @@ def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: def _augment_graph_structure( self, mol: Chem.Mol - ) -> Tuple[torch.Tensor, dict, dict]: + ) -> Optional[Tuple[torch.Tensor, dict, dict]]: """ Constructs the full augmented graph structure from a molecule. @@ -211,9 +228,15 @@ def _augment_graph_structure( atom_edge_index = self._generate_atom_level_edge_index(mol) # Create FG-level structure and edges + returned_result = self._construct_fg_to_atom_structure(mol) + + if returned_result is None: + return None + fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds = ( - self._construct_fg_to_atom_structure(mol) + returned_result ) + fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure( structured_fg_map, bonds ) @@ -282,7 +305,7 @@ def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: def _construct_fg_to_atom_structure( self, mol: Chem.Mol - ) -> Tuple[List[List[int]], dict, dict, dict, list]: + ) -> Optional[Tuple[List[List[int]], dict, dict, dict, list]]: """ Constructs edges between functional group (FG) nodes and atom nodes. @@ -346,6 +369,11 @@ def _construct_fg_to_atom_structure( mol.GetAtomWithIdx(i).GetProp("FG") for i in structure[fg_key]["atom"] } + + if "" in fg_set and len(fg_set) == 1: + # There will be no FGs for wildcard SMILES Eg. 
CHEBI:33429 + return None + if "" in fg_set or len(fg_set) > 1: raise ValueError("Invalid functional group assignment to atoms.") From 8fd45d1e94bd5ba49faaf1129dcaa9ba94e83198 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 10 May 2025 21:10:08 +0200 Subject: [PATCH 063/224] fix to cal right number of nodes --- .../preprocessing/reader/augmented_reader.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 51b0a3b..2a1306d 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -42,8 +42,8 @@ def __init__(self, *args, **kwargs): 0 # Record number of failure during augmented graph construction ) self.mol_object_buffer = {} - self.num_nodes = 0 - self.num_of_edges = 0 + self._num_of_nodes = 0 + self._num_of_edges = 0 @classmethod @abstractmethod @@ -132,7 +132,11 @@ def read_property(self, smiles: str, property: MolecularProperty) -> Optional[Li if mol is None: return None - _, augmented_mol = self._create_augmented_graph(mol) + returned_result = self._create_augmented_graph(mol) + if returned_result is None: + return None + + _, augmented_mol = returned_result return property.get_property_value(augmented_mol) @@ -175,15 +179,11 @@ def _read_data(self, smiles: str) -> GeomData | None: return None edge_index, augmented_molecule = returned_result - self.mol_object_buffer[smiles] = augmented_molecule - num_nodes = augmented_molecule["nodes"]["num_nodes"] - num_edges = augmented_molecule["edges"]["num_edges"] - # Empty features initialized; node and edge features can be added later - x = torch.zeros((num_nodes, 0)) - edge_attr = torch.zeros((num_edges, 0)) + x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0)) + edge_attr = torch.zeros((augmented_molecule["edges"]["num_edges"], 0)) return GeomData(x=x, edge_index=edge_index, 
edge_attr=edge_attr) @@ -221,8 +221,8 @@ def _augment_graph_structure( Returns: Tuple[torch.Tensor, dict, dict]: Edge index, node metadata, and edge metadata. """ - self.num_of_nodes = mol.GetNumAtoms() - self.num_of_edges = mol.GetNumBonds() + self._num_of_nodes = mol.GetNumAtoms() + self._num_of_edges = mol.GetNumBonds() self._annotate_atoms_and_bonds(mol) atom_edge_index = self._generate_atom_level_edge_index(mol) @@ -259,14 +259,14 @@ def _augment_graph_structure( "atom_nodes": mol, "fg_nodes": fg_nodes, "graph_node": graph_node, - "num_nodes": self.num_of_nodes, + "num_nodes": self._num_of_nodes, } edge_info = { WITHIN_ATOMS_EDGE: mol, ATOM_FG_EDGE: atom_fg_edges, WITHIN_FG_EDGE: internal_fg_edges, FG_GRAPHNODE_EDGE: fg_to_graph_edges, - "num_edges": self.num_of_edges, + "num_edges": self._num_of_edges, } return full_edge_index, node_info, edge_info @@ -331,16 +331,16 @@ def _construct_fg_to_atom_structure( ) # Contains augmented fg-nodes and connected atoms indices for idx, fg_key in enumerate(structure): - structured_fg_map[self.num_of_nodes] = {"atom": structure[fg_key]["atom"]} + structured_fg_map[self._num_of_nodes] = {"atom": structure[fg_key]["atom"]} # Build edge index for fg to atom nodes connections for atom_idx in structure[fg_key]["atom"]: - fg_atom_edge_index[0] += [self.num_of_nodes, atom_idx] - fg_atom_edge_index[1] += [atom_idx, self.num_of_nodes] - atom_fg_edges[f"{self.num_of_nodes}_{atom_idx}"] = { + fg_atom_edge_index[0] += [self._num_of_nodes, atom_idx] + fg_atom_edge_index[1] += [atom_idx, self._num_of_nodes] + atom_fg_edges[f"{self._num_of_nodes}_{atom_idx}"] = { EDGE_LEVEL: ATOM_FG_EDGE } - self.num_of_edges += 1 + self._num_of_edges += 1 # Identify ring vs. 
functional group type ring_fg = { @@ -358,7 +358,7 @@ def _construct_fg_to_atom_structure( len(ring_fg) == 1 ): # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings ring_size = next(iter(ring_fg)) - fg_nodes[self.num_of_nodes] = { + fg_nodes[self._num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, # E.g., Fused Ring has size "5-6", indicating size of each connected ring in fused ring "FG": f"RING_{ring_size}", @@ -380,7 +380,7 @@ def _construct_fg_to_atom_structure( for atom_idx in structure[fg_key]["atom"]: atom = mol.GetAtomWithIdx(atom_idx) if atom.GetProp("FG"): - fg_nodes[self.num_of_nodes] = { + fg_nodes[self._num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, "FG": atom.GetProp("FG"), "RING": atom.GetProp("RING"), @@ -391,7 +391,7 @@ def _construct_fg_to_atom_structure( "Expected at least one atom with a functional group." ) - self.num_of_nodes += 1 + self._num_of_nodes += 1 return fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds @@ -428,7 +428,7 @@ def _construct_fg_level_structure( internal_edge_index[0] += [source_fg, target_fg] internal_edge_index[1] += [target_fg, source_fg] internal_fg_edges[f"{source_fg}_{target_fg}"] = {EDGE_LEVEL: WITHIN_FG_EDGE} - self.num_of_edges += 1 + self._num_of_edges += 1 return internal_edge_index, internal_fg_edges @@ -450,11 +450,12 @@ def _construct_fg_to_graph_node_structure( graph_edge_index = [[], []] for fg_id in structured_fg_map: - graph_edge_index[0] += [self.num_of_nodes, fg_id] - graph_edge_index[1] += [fg_id, self.num_of_nodes] - fg_graph_edges[f"{self.num_of_nodes}_{fg_id}"] = { + graph_edge_index[0] += [self._num_of_nodes, fg_id] + graph_edge_index[1] += [fg_id, self._num_of_nodes] + fg_graph_edges[f"{self._num_of_nodes}_{fg_id}"] = { EDGE_LEVEL: FG_GRAPHNODE_EDGE } - self.num_of_edges += 1 + self._num_of_edges += 1 + self._num_of_nodes += 1 return graph_edge_index, graph_node, fg_graph_edges From e45c7c183e51841c204d5b251d3362a8618358b6 Mon Sep 17 00:00:00 2001 From: aditya0by0 
Date: Sat, 10 May 2025 21:11:36 +0200 Subject: [PATCH 064/224] assert if num of prop values is equal to num of nodes/edges --- .../preprocessing/properties/augmented_properties.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index fa27c17..ebbf455 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Dict, List, Optional from rdkit import Chem @@ -12,7 +12,7 @@ class AugmentedBondProperty(BondProperty, ABC): MAIN_KEY = "edges" - def get_property_value(self, augmented_mol: Dict): + def get_property_value(self, augmented_mol: Dict) -> List: if self.MAIN_KEY not in augmented_mol: raise KeyError( f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" @@ -50,6 +50,9 @@ def get_property_value(self, augmented_mol: Dict): prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges]) prop_list.extend([self.get_bond_value(bond) for bond in fg_edges]) prop_list.extend([self.get_bond_value(bond) for bond in fg_graph_node_edges]) + assert ( + len(prop_list) == augmented_mol[self.MAIN_KEY]["num_edges"] + ), "Number of property values should be equal to number of edges" return prop_list @@ -110,6 +113,9 @@ def get_property_value(self, augmented_mol: Dict): # https://mail.python.org/pipermail/python-dev/2017-December/151283.html prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()]) prop_list.append(self.get_atom_value(graph_node)) + assert ( + len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"] + ), "Number of property values should be equal to number of nodes" return prop_list From 011ed737aa078b08832362544be4c7fa2732f860 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 
May 2025 13:36:40 +0200 Subject: [PATCH 065/224] take value from reader if not None --- chebai_graph/preprocessing/datasets/chebi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 368664f..7244e20 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -111,10 +111,10 @@ def _setup_properties(self): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode property_values = [ - val + self.reader.read_property(feat, property) for feat in tqdm.tqdm(features) - if (val := self.reader.read_property(feat, property)) is not None ] + property.encoder.on_start(property_values=property_values) encoded_values = [ enc_if_not_none(property.encoder.encode, value) @@ -149,10 +149,10 @@ def setup(self, **kwargs): def _merge_props_into_base(self, row): geom_data = row["features"] + assert isinstance(geom_data, GeomData) edge_attr = geom_data.edge_attr x = geom_data.x molecule_attr = torch.empty((1, 0)) - assert isinstance(geom_data, GeomData) for property in self.properties: property_values = row[f"{property.name}"] if isinstance(property_values, torch.Tensor): @@ -233,5 +233,5 @@ class ChEBI50GraphPropertiesPartial(ChEBI50GraphProperties, ChEBIOverXPartial): pass -class ChEBI50GraphFGAugmentorReader(ChEBI50GraphProperties): +class ChEBI50GraphFGAugmentorReader(GraphPropertiesMixIn, ChEBIOver50): READER = GraphFGAugmentorReader From b8189d183190e559f6a1288e4bae17af8a4971d0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 13:37:46 +0200 Subject: [PATCH 066/224] add test data --- tests/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/test_data.py | 119 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 
tests/unit/test_data.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py new file mode 100644 index 0000000..aea97a0 --- /dev/null +++ b/tests/unit/test_data.py @@ -0,0 +1,119 @@ +import torch +from torch_geometric.data import Data + + +class MoleculeGraph: + """Dummy graph of Aspirin with node and edge features""" + + def get_aspirin_graph(self): + """ + Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365 + + Node labels (atom indices): + O2 C5———C6 + \ / \ + C1———O3———C4 C7 + / \ / + C0 C9———C8 + / + C10 + / \ + O12 O11 + """ + + # --- Node features: atomic numbers (C=6, O=8) --- + # Shape of x : num_nodes x num_of_node_features + x = torch.tensor( + [ + [6], # C0 - This feature belongs to atom with atom `0` in edge_index + [6], # C1 - This feature belongs to atom with atom `1` in edge_index + [8], # O2 - This feature belongs to atom with atom `2` in edge_index + [8], # O3 - This feature belongs to atom with atom `3` in edge_index + [6], # C4 - This feature belongs to atom with atom `4` in edge_index + [6], # C5 - This feature belongs to atom with atom `5` in edge_index + [6], # C6 - This feature belongs to atom with atom `6` in edge_index + [6], # C7 - This feature belongs to atom with atom `7` in edge_index + [6], # C8 - This feature belongs to atom with atom `8` in edge_index + [6], # C9 - This feature belongs to atom with atom `9` in edge_index + [6], # C10 - This feature belongs to atom with atom `10` in edge_index + [8], # O11 - This feature belongs to atom with atom `11` in edge_index + [8], # O12 - This feature belongs to atom with atom `12` in edge_index + ], + dtype=torch.float, + ) + + # --- Edge list (bidirectional) --- + # Shape of edge_index for undirected graph: 2 x num_of_edges + edge_index = ( + torch.tensor( + [ + [0, 1], + [1, 0], + [1, 
2], + [2, 1], + [1, 3], + [3, 1], + [3, 4], + [4, 3], + [4, 5], + [5, 4], + [5, 6], + [6, 5], + [6, 7], + [7, 6], + [7, 8], + [8, 7], + [8, 9], + [9, 8], + [4, 9], + [9, 4], + [9, 10], + [10, 9], + [10, 11], + [11, 10], + [10, 12], + [12, 10], + ], + dtype=torch.long, + ) + .t() + .contiguous() + ) + + # --- Dummy edge features: bond type (single=1, double=2, ester=3) --- + # Using all single bonds for simplicity (except C=O as double bonds) + # Shape of edge_attr: num_of_edges x num_of_edges_features + edge_attr = torch.tensor( + [ + [1], + [1], # C0 - C1 # This two features to two first bond in + [2], + [2], # C1 = O2 (double bond) + [1], + [1], # C1 - O3 + [1], + [1], # O3 - C4 + [1], + [1], # C4 - C5 + [1], + [1], # C5 - C6 + [1], + [1], # C6 - C7 + [1], + [1], # C7 - C8 + [1], + [1], # C8 - C9 + [1], + [1], # C4 - C9 (ring closure) + [1], + [1], # C9 - C10 + [2], + [2], # C10 = O11 (carboxylic acid) + [1], + [1], # C10 - O12 (hydroxyl) + ], + dtype=torch.float, + ) + + # Create graph data object + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) From ad301e6dda67ed7b067937301931844ccc4ff2d6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 13:45:34 +0200 Subject: [PATCH 067/224] edge_features should be calculated after undirected graph --- chebai_graph/preprocessing/reader.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index ced9f31..9862fe9 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -55,15 +55,19 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - edge_attr = torch.zeros((mol.GetNumBonds(), 0)) - - edge_index = torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] + edge_index = to_undirected( + torch.tensor( + [ + [bond.GetBeginAtomIdx() for bond in 
mol.GetBonds()], + [bond.GetEndAtomIdx() for bond in mol.GetBonds()], + ] + ) ) - return GeomData(x=x, edge_index=to_undirected(edge_index), edge_attr=edge_attr) + + # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] + edge_attr = torch.zeros((edge_index.size(1), 0)) + + return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) def on_finish(self): rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") From 344d828adb73b5cb28c5ff04dca122287dd7650e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 16:51:26 +0200 Subject: [PATCH 068/224] directed edge which form an un-dir edge should be adjancent --- chebai_graph/preprocessing/reader.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 9862fe9..ac6f212 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -55,14 +55,12 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - edge_index = to_undirected( - torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] - ) - ) + # We need to ensure that directed edges which form a undirected edge are adjacent to each other + edge_index_list = [[], []] + for bond in mol.GetBonds(): + edge_index_list[0].extend([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) + edge_index_list[1].extend([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) + edge_index = torch.tensor(edge_index_list, dtype=torch.long) # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] edge_attr = torch.zeros((edge_index.size(1), 0)) From 8a69828818dc09046806f06c7c8a15e18676f3fa Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 16:59:51 +0200 Subject: [PATCH 069/224] add test for GraphPropertyReader --- tests/unit/readers/__init__.py | 0 
tests/unit/readers/testGraphPropertyReader.py | 61 +++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 tests/unit/readers/__init__.py create mode 100644 tests/unit/readers/testGraphPropertyReader.py diff --git a/tests/unit/readers/__init__.py b/tests/unit/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py new file mode 100644 index 0000000..c3771c3 --- /dev/null +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -0,0 +1,61 @@ +import unittest + +import torch +from torch_geometric.data import Data as GeomData + +from chebai_graph.preprocessing.reader import GraphPropertyReader +from tests.unit.test_data import MoleculeGraph + + +class TestGraphPropertyReader(unittest.TestCase): + """Unit tests for the GraphPropertyReader class, which converts SMILES strings to torch_geometric Data objects.""" + + def setUp(self) -> None: + """Initialize the reader and the reference molecule graph.""" + self.reader: GraphPropertyReader = GraphPropertyReader() + self.molecule_graph: MoleculeGraph = MoleculeGraph() + + def test_read_data(self) -> None: + """Test that the reader correctly parses a SMILES string into a graph and matches expected aspirin structure.""" + smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin + + data: GeomData = self.reader._read_data(smiles) + expected_data: GeomData = self.molecule_graph.get_aspirin_graph() + + self.assertIsInstance( + data, + GeomData, + msg="The output should be an instance of torch_geometric.data.Data.", + ) + + self.assertTrue( + torch.equal(data.edge_index, expected_data.edge_index), + msg=( + "edge_index tensors do not match.\n" + f"Differences at indices: {(data.edge_index != expected_data.edge_index).nonzero()}.\n" + f"Parsed edge_index:\n{data.edge_index}\nExpected edge_index:\n{expected_data.edge_index}" + f"If fails in future, check if there is change in RDKIT version, the expected graph 
is generated with RDKIT 2024.9.6" + ), + ) + + self.assertEqual( + data.x.shape[0], + expected_data.x.shape[0], + msg=( + "The number of atoms (nodes) in the parsed graph does not match the reference graph.\n" + f"Parsed: {data.x.shape[0]}, Expected: {expected_data.x.shape[0]}" + ), + ) + + self.assertEqual( + data.edge_attr.shape[0], + expected_data.edge_attr.shape[0], + msg=( + "The number of edge attributes does not match the expected value.\n" + f"Parsed: {data.edge_attr.shape[0]}, Expected: {expected_data.edge_attr.shape[0]}" + ), + ) + + +if __name__ == "__main__": + unittest.main() From e0064b837394b0341b34bb3969ce738c15d074fb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 17:18:14 +0200 Subject: [PATCH 070/224] add gt test data for aspirin --- tests/unit/test_data.py | 163 ++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 83 deletions(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index aea97a0..8aff95b 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -3,10 +3,12 @@ class MoleculeGraph: - """Dummy graph of Aspirin with node and edge features""" + """Class representing molecular graph data.""" def get_aspirin_graph(self): """ + Constructs and returns a PyTorch Geometric Data object representing the molecular graph of Aspirin. + Aspirin -> CC(=O)OC1=CC=CC=C1C(=O)O ; CHEBI:15365 Node labels (atom indices): @@ -19,101 +21,96 @@ def get_aspirin_graph(self): C10 / \ O12 O11 + + + Returns: + torch_geometric.data.Data: A Data object with attributes: + - x (FloatTensor): Node feature matrix of shape (num_nodes, 1). + - edge_index (LongTensor): Graph connectivity in COO format of shape (2, num_edges). + - edge_attr (FloatTensor): Edge feature matrix of shape (num_edges, 1). 
+ + Refer: + For graph construction: https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html """ # --- Node features: atomic numbers (C=6, O=8) --- # Shape of x : num_nodes x num_of_node_features x = torch.tensor( [ - [6], # C0 - This feature belongs to atom with atom `0` in edge_index - [6], # C1 - This feature belongs to atom with atom `1` in edge_index - [8], # O2 - This feature belongs to atom with atom `2` in edge_index - [8], # O3 - This feature belongs to atom with atom `3` in edge_index - [6], # C4 - This feature belongs to atom with atom `4` in edge_index - [6], # C5 - This feature belongs to atom with atom `5` in edge_index - [6], # C6 - This feature belongs to atom with atom `6` in edge_index - [6], # C7 - This feature belongs to atom with atom `7` in edge_index - [6], # C8 - This feature belongs to atom with atom `8` in edge_index - [6], # C9 - This feature belongs to atom with atom `9` in edge_index - [6], # C10 - This feature belongs to atom with atom `10` in edge_index - [8], # O11 - This feature belongs to atom with atom `11` in edge_index - [8], # O12 - This feature belongs to atom with atom `12` in edge_index + [ + 6 + ], # C0 - This feature belongs to atom/node with `0` value in edge_index + [ + 6 + ], # C1 - This feature belongs to atom/node with `1` value in edge_index + [ + 8 + ], # O2 - This feature belongs to atom/node with `2` value in edge_index + [ + 8 + ], # O3 - This feature belongs to atom/node with `3` value in edge_index + [ + 6 + ], # C4 - This feature belongs to atom/node with `4` value in edge_index + [ + 6 + ], # C5 - This feature belongs to atom/node with `5` value in edge_index + [ + 6 + ], # C6 - This feature belongs to atom/node with `6` value in edge_index + [ + 6 + ], # C7 - This feature belongs to atom/node with `7` value in edge_index + [ + 6 + ], # C8 - This feature belongs to atom/node with `8` value in edge_index + [ + 6 + ], # C9 - This feature belongs to atom/node with `9` value in edge_index 
+ [ + 6 + ], # C10 - This feature belongs to atom/node with `10` value in edge_index + [ + 8 + ], # O11 - This feature belongs to atom/node with `11` value in edge_index + [ + 8 + ], # O12 - This feature belongs to atom/node with `12` value in edge_index ], dtype=torch.float, ) # --- Edge list (bidirectional) --- - # Shape of edge_index for undirected graph: 2 x num_of_edges - edge_index = ( - torch.tensor( - [ - [0, 1], - [1, 0], - [1, 2], - [2, 1], - [1, 3], - [3, 1], - [3, 4], - [4, 3], - [4, 5], - [5, 4], - [5, 6], - [6, 5], - [6, 7], - [7, 6], - [7, 8], - [8, 7], - [8, 9], - [9, 8], - [4, 9], - [9, 4], - [9, 10], - [10, 9], - [10, 11], - [11, 10], - [10, 12], - [12, 10], - ], - dtype=torch.long, - ) - .t() - .contiguous() - ) + # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) + # 2 directed edges of one undirected edge are adjacent to each other --- this is needed + + # fmt: off + # Generated using RDKIT 2024.9.6 + edge_index = torch.tensor([ + [0, 1, 1, 2, 1, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 10, 12, 9, 4], # Start atoms (u) + [1, 0, 2, 1, 3, 1, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8, 10, 9, 11, 10, 12, 10, 4, 9] # End atoms (v) + ], dtype=torch.long) + # fmt: on - # --- Dummy edge features: bond type (single=1, double=2, ester=3) --- - # Using all single bonds for simplicity (except C=O as double bonds) + # --- Dummy edge features --- # Shape of edge_attr: num_of_edges x num_of_edges_features - edge_attr = torch.tensor( - [ - [1], - [1], # C0 - C1 # This two features to two first bond in - [2], - [2], # C1 = O2 (double bond) - [1], - [1], # C1 - O3 - [1], - [1], # O3 - C4 - [1], - [1], # C4 - C5 - [1], - [1], # C5 - C6 - [1], - [1], # C6 - C7 - [1], - [1], # C7 - C8 - [1], - [1], # C8 - C9 - [1], - [1], # C4 - C9 (ring closure) - [1], - [1], # C9 - C10 - [2], - [2], # C10 = O11 (carboxylic acid) - [1], - [1], # C10 - O12 (hydroxyl) - ], - dtype=torch.float, - ) + # fmt: off + edge_attr = torch.tensor([ + [1], [1], # C0 - 
C1, This two features belong to elements at index 0 and 1 in `edge_index` + [2], [2], # C1 - C2, This two features belong to elements at index 2 and 3 in `edge_index` + [2], [2], # C1 - O3, This two features belong to elements at index 4 and 5 in `edge_index` + [2], [2], # O3 - C4, This two features belong to elements at index 6 and 7 in `edge_index` + [1], [1], # C4 - C5, This two features belong to elements at index 8 and 9 in `edge_index` + [1], [1], # C5 - C6, This two features belong to elements at index 10 and 11 in `edge_index` + [1], [1], # C6 - C7, This two features belong to elements at index 12 and 13 in `edge_index` + [1], [1], # C7 - C8, This two features belong to elements at index 14 and 15 in `edge_index` + [1], [1], # C8 - C9, This two features belong to elements at index 16 and 17 in `edge_index` + [1], [1], # C9 - C10, This two features belong to elements at index 18 and 19 in `edge_index` + [1], [1], # C10 - O11, This two features belong to elements at index 20 and 21 in `edge_index` + [1], [1], # C10 - O12, This two features belong to elements at index 22 and 23 in `edge_index` + [1], [1], # C9 - C4, This two features belong to elements at index 24 and 25 in `edge_index` + ], dtype=torch.float) + # fmt: on # Create graph data object return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) From a9c722888e846c10d4ec82953ce99ef2a601555a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 17:33:18 +0200 Subject: [PATCH 071/224] Update test_data.py --- tests/unit/test_data.py | 54 ++++++++++++----------------------------- 1 file changed, 15 insertions(+), 39 deletions(-) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 8aff95b..cd3b16f 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -35,50 +35,26 @@ def get_aspirin_graph(self): # --- Node features: atomic numbers (C=6, O=8) --- # Shape of x : num_nodes x num_of_node_features + # fmt: off x = torch.tensor( [ - [ - 6 - ], # C0 - This feature 
belongs to atom/node with `0` value in edge_index - [ - 6 - ], # C1 - This feature belongs to atom/node with `1` value in edge_index - [ - 8 - ], # O2 - This feature belongs to atom/node with `2` value in edge_index - [ - 8 - ], # O3 - This feature belongs to atom/node with `3` value in edge_index - [ - 6 - ], # C4 - This feature belongs to atom/node with `4` value in edge_index - [ - 6 - ], # C5 - This feature belongs to atom/node with `5` value in edge_index - [ - 6 - ], # C6 - This feature belongs to atom/node with `6` value in edge_index - [ - 6 - ], # C7 - This feature belongs to atom/node with `7` value in edge_index - [ - 6 - ], # C8 - This feature belongs to atom/node with `8` value in edge_index - [ - 6 - ], # C9 - This feature belongs to atom/node with `9` value in edge_index - [ - 6 - ], # C10 - This feature belongs to atom/node with `10` value in edge_index - [ - 8 - ], # O11 - This feature belongs to atom/node with `11` value in edge_index - [ - 8 - ], # O12 - This feature belongs to atom/node with `12` value in edge_index + [6], # C0 - This feature belongs to atom/node with 0 value in edge_index + [6], # C1 - This feature belongs to atom/node with 1 value in edge_index + [8], # O2 - This feature belongs to atom/node with 2 value in edge_index + [8], # O3 - This feature belongs to atom/node with 3 value in edge_index + [6], # C4 - This feature belongs to atom/node with 4 value in edge_index + [6], # C5 - This feature belongs to atom/node with 5 value in edge_index + [6], # C6 - This feature belongs to atom/node with 6 value in edge_index + [6], # C7 - This feature belongs to atom/node with 7 value in edge_index + [6], # C8 - This feature belongs to atom/node with 8 value in edge_index + [6], # C9 - This feature belongs to atom/node with 9 value in edge_index + [6], # C10 - This feature belongs to atom/node with 10 value in edge_index + [8], # O11 - This feature belongs to atom/node with 11 value in edge_index + [8], # O12 - This feature belongs to 
atom/node with 12 value in edge_index ], dtype=torch.float, ) + # fmt: on # --- Edge list (bidirectional) --- # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) From fa8dcc5c9bc17eb95fb7f317cb5675cb77f2a09b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 19:46:02 +0200 Subject: [PATCH 072/224] dir edges of undir should be adjacent --- .../preprocessing/reader/augmented_reader.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 2a1306d..aca453c 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -295,13 +295,12 @@ def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: Returns: torch.Tensor: Bidirectional edge index tensor. """ - edge_index = torch.tensor( - [ - [bond.GetBeginAtomIdx() for bond in mol.GetBonds()], - [bond.GetEndAtomIdx() for bond in mol.GetBonds()], - ] - ) - return torch.cat([edge_index, edge_index[[1, 0], :]], dim=1) + # We need to ensure that directed edges which form a undirected edge are adjacent to each other + edge_index_list = [[], []] + for bond in mol.GetBonds(): + edge_index_list[0].extend([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) + edge_index_list[1].extend([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) + return torch.tensor(edge_index_list, dtype=torch.long) def _construct_fg_to_atom_structure( self, mol: Chem.Mol From 1ef714fa4558c895e8e8c8d4ced2d892e6f3c80f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 19:48:51 +0200 Subject: [PATCH 073/224] assert checks for aug reader --- .../preprocessing/reader/augmented_reader.py | 41 +++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index aca453c..334dad5 100644 
--- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -183,7 +183,19 @@ def _read_data(self, smiles: str) -> GeomData | None: # Empty features initialized; node and edge features can be added later x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0)) - edge_attr = torch.zeros((augmented_molecule["edges"]["num_edges"], 0)) + edge_attr = torch.zeros((augmented_molecule["edges"]["num_edges"] * 2, 0)) + + assert ( + edge_index.shape[0] == 2 + ), f"Expected edge_index to have shape [2, num_edges], but got shape {edge_index.shape}" + + assert ( + edge_index.shape[1] == edge_attr.shape[0] + ), f"Mismatch between number of edges in edge_index ({edge_index.shape[1]}) and edge_attr ({edge_attr.shape[0]})" + + assert ( + len(set(edge_index[0].tolist())) == x.shape[0] + ), f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})" return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) @@ -329,13 +341,13 @@ def _construct_fg_to_atom_structure( {} ) # Contains augmented fg-nodes and connected atoms indices - for idx, fg_key in enumerate(structure): - structured_fg_map[self._num_of_nodes] = {"atom": structure[fg_key]["atom"]} + for fg_group in structure.values(): + structured_fg_map[self._num_of_nodes] = {"atom": fg_group["atom"]} # Build edge index for fg to atom nodes connections - for atom_idx in structure[fg_key]["atom"]: - fg_atom_edge_index[0] += [self._num_of_nodes, atom_idx] - fg_atom_edge_index[1] += [atom_idx, self._num_of_nodes] + for atom_idx in fg_group["atom"]: + fg_atom_edge_index[0].extend([self._num_of_nodes, atom_idx]) + fg_atom_edge_index[1].extend([atom_idx, self._num_of_nodes]) atom_fg_edges[f"{self._num_of_nodes}_{atom_idx}"] = { EDGE_LEVEL: ATOM_FG_EDGE } @@ -344,7 +356,7 @@ def _construct_fg_to_atom_structure( # Identify ring vs. 
functional group type ring_fg = { mol.GetAtomWithIdx(i).GetProp("RING") - for i in structure[fg_key]["atom"] + for i in fg_group["atom"] if mol.GetAtomWithIdx(i).GetProp("RING") } @@ -364,10 +376,7 @@ def _construct_fg_to_atom_structure( "RING": ring_size, } else: # No connected has a ring size which indicates it is simple FG - fg_set = { - mol.GetAtomWithIdx(i).GetProp("FG") - for i in structure[fg_key]["atom"] - } + fg_set = {mol.GetAtomWithIdx(i).GetProp("FG") for i in fg_group["atom"]} if "" in fg_set and len(fg_set) == 1: # There will be no FGs for wildcard SMILES Eg. CHEBI:33429 @@ -376,7 +385,7 @@ def _construct_fg_to_atom_structure( if "" in fg_set or len(fg_set) > 1: raise ValueError("Invalid functional group assignment to atoms.") - for atom_idx in structure[fg_key]["atom"]: + for atom_idx in fg_group["atom"]: atom = mol.GetAtomWithIdx(atom_idx) if atom.GetProp("FG"): fg_nodes[self._num_of_nodes] = { @@ -424,8 +433,8 @@ def _construct_fg_level_structure( source_fg is not None and target_fg is not None ), "Each bond should have a fg node on both end" - internal_edge_index[0] += [source_fg, target_fg] - internal_edge_index[1] += [target_fg, source_fg] + internal_edge_index[0].extend([source_fg, target_fg]) + internal_edge_index[1].extend([target_fg, source_fg]) internal_fg_edges[f"{source_fg}_{target_fg}"] = {EDGE_LEVEL: WITHIN_FG_EDGE} self._num_of_edges += 1 @@ -449,8 +458,8 @@ def _construct_fg_to_graph_node_structure( graph_edge_index = [[], []] for fg_id in structured_fg_map: - graph_edge_index[0] += [self._num_of_nodes, fg_id] - graph_edge_index[1] += [fg_id, self._num_of_nodes] + graph_edge_index[0].extend([self._num_of_nodes, fg_id]) + graph_edge_index[1].extend([fg_id, self._num_of_nodes]) fg_graph_edges[f"{self._num_of_nodes}_{fg_id}"] = { EDGE_LEVEL: FG_GRAPHNODE_EDGE } From 0a9760ddd3e4b4c7e3b4bd4227a91f4ff8b00d8a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 20:30:06 +0200 Subject: [PATCH 074/224] add more graph test --- 
tests/unit/readers/testGraphPropertyReader.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py index c3771c3..a4f18eb 100644 --- a/tests/unit/readers/testGraphPropertyReader.py +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -19,8 +19,7 @@ def test_read_data(self) -> None: """Test that the reader correctly parses a SMILES string into a graph and matches expected aspirin structure.""" smiles: str = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin - data: GeomData = self.reader._read_data(smiles) - expected_data: GeomData = self.molecule_graph.get_aspirin_graph() + data: GeomData = self.reader._read_data(smiles) # noqa self.assertIsInstance( data, @@ -28,6 +27,19 @@ def test_read_data(self) -> None: msg="The output should be an instance of torch_geometric.data.Data.", ) + assert ( + data.edge_index.shape[0] == 2 + ), f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}" + + assert ( + data.edge_index.shape[1] == data.edge_attr.shape[0] + ), f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})" + + assert ( + len(set(data.edge_index[0].tolist())) == data.x.shape[0] + ), f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})" + + expected_data: GeomData = self.molecule_graph.get_aspirin_graph() self.assertTrue( torch.equal(data.edge_index, expected_data.edge_index), msg=( From 1a8dcb60897bfb5027d731647cc1955cc561b849 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:16:21 +0200 Subject: [PATCH 075/224] first src to tgt edges then tgt to src - instead of using adjacent directed edge, this one is better approach since we can stack edge attributes generated later without any further logic to rearrange edge_attr --- 
chebai_graph/preprocessing/reader.py | 13 +++---- tests/unit/test_data.py | 54 ++++++++++++++++------------ 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index ac6f212..6cd4ecd 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -8,9 +8,8 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData -from torch_geometric.utils import from_networkx, to_undirected -import chebai_graph.preprocessing.properties as properties +from chebai_graph.preprocessing import properties from chebai_graph.preprocessing.collate import GraphCollator @@ -55,12 +54,10 @@ def _read_data(self, raw_data): x = torch.zeros((mol.GetNumAtoms(), 0)) - # We need to ensure that directed edges which form a undirected edge are adjacent to each other - edge_index_list = [[], []] - for bond in mol.GetBonds(): - edge_index_list[0].extend([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) - edge_index_list[1].extend([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) - edge_index = torch.tensor(edge_index_list, dtype=torch.long) + # First source to target edges, then target to source edges + src = [bond.GetBeginAtomIdx() for bond in mol.GetBonds()] + tgt = [bond.GetEndAtomIdx() for bond in mol.GetBonds()] + edge_index = torch.tensor([src + tgt, tgt + src], dtype=torch.long) # edge_index.shape == [2, num_edges]; edge_attr.shape == [num_edges, num_edge_features] edge_attr = torch.zeros((edge_index.size(1), 0)) diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index cd3b16f..4acf41d 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -58,35 +58,45 @@ def get_aspirin_graph(self): # --- Edge list (bidirectional) --- # Shape of edge_index for undirected graph: 2 x num_of_edges; (2x26) - # 2 directed edges of one undirected edge are adjacent to each other --- this is 
needed - - # fmt: off # Generated using RDKIT 2024.9.6 - edge_index = torch.tensor([ - [0, 1, 1, 2, 1, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 10, 12, 9, 4], # Start atoms (u) - [1, 0, 2, 1, 3, 1, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8, 10, 9, 11, 10, 12, 10, 4, 9] # End atoms (v) + # fmt: off + _edge_index = torch.tensor([ + [0, 1, 1, 3, 4, 5, 6, 7, 8, 9, 10, 10, 9], # Start atoms (u) + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 4] # End atoms (v) ], dtype=torch.long) # fmt: on + # Reverse the edges + reversed_edge_index = _edge_index[[1, 0], :] + + # First all directed edges from source to target are placed, + # then all directed edges from target to source are placed --- this is needed + undirected_edge_index = torch.cat([_edge_index, reversed_edge_index], dim=1) + # --- Dummy edge features --- - # Shape of edge_attr: num_of_edges x num_of_edges_features + # Shape of undirected_edge_attr: num_of_edges x num_of_edges_features (26 x 1) # fmt: off - edge_attr = torch.tensor([ - [1], [1], # C0 - C1, This two features belong to elements at index 0 and 1 in `edge_index` - [2], [2], # C1 - C2, This two features belong to elements at index 2 and 3 in `edge_index` - [2], [2], # C1 - O3, This two features belong to elements at index 4 and 5 in `edge_index` - [2], [2], # O3 - C4, This two features belong to elements at index 6 and 7 in `edge_index` - [1], [1], # C4 - C5, This two features belong to elements at index 8 and 9 in `edge_index` - [1], [1], # C5 - C6, This two features belong to elements at index 10 and 11 in `edge_index` - [1], [1], # C6 - C7, This two features belong to elements at index 12 and 13 in `edge_index` - [1], [1], # C7 - C8, This two features belong to elements at index 14 and 15 in `edge_index` - [1], [1], # C8 - C9, This two features belong to elements at index 16 and 17 in `edge_index` - [1], [1], # C9 - C10, This two features belong to elements at index 18 and 19 in `edge_index` - [1], [1], # C10 - O11, This two features belong to 
elements at index 20 and 21 in `edge_index` - [1], [1], # C10 - O12, This two features belong to elements at index 22 and 23 in `edge_index` - [1], [1], # C9 - C4, This two features belong to elements at index 24 and 25 in `edge_index` + _edge_attr = torch.tensor([ + [1], # C0 - C1, This two features belong to elements at index 0 in `edge_index` + [2], # C1 - C2, This two features belong to elements at index 1 in `edge_index` + [2], # C1 - O3, This two features belong to elements at index 2 in `edge_index` + [2], # O3 - C4, This two features belong to elements at index 3 in `edge_index` + [1], # C4 - C5, This two features belong to elements at index 4 in `edge_index` + [1], # C5 - C6, This two features belong to elements at index 5 in `edge_index` + [1], # C6 - C7, This two features belong to elements at index 6 in `edge_index` + [1], # C7 - C8, This two features belong to elements at index 7 in `edge_index` + [1], # C8 - C9, This two features belong to elements at index 8 in `edge_index` + [1], # C9 - C10, This two features belong to elements at index 9 in `edge_index` + [1], # C10 - O11, This two features belong to elements at index 10 in `edge_index` + [1], # C10 - O12, This two features belong to elements at index 11 in `edge_index` + [1], # C9 - C4, This two features belong to elements at index 12 in `edge_index` ], dtype=torch.float) # fmt: on + # Alignement of edge attributes should in same order as of edge_index + undirected_edge_attr = torch.cat([_edge_attr, _edge_attr], dim=0) + # Create graph data object - return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + return Data( + x=x, edge_index=undirected_edge_index, edge_attr=undirected_edge_attr + ) From 5d4c174e314ebede5283ded03fd88a794dd49aa9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:16:50 +0200 Subject: [PATCH 076/224] add test for duplicate directed edges --- tests/unit/readers/testGraphPropertyReader.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 
deletions(-) diff --git a/tests/unit/readers/testGraphPropertyReader.py b/tests/unit/readers/testGraphPropertyReader.py index a4f18eb..0222fa4 100644 --- a/tests/unit/readers/testGraphPropertyReader.py +++ b/tests/unit/readers/testGraphPropertyReader.py @@ -27,17 +27,30 @@ def test_read_data(self) -> None: msg="The output should be an instance of torch_geometric.data.Data.", ) - assert ( - data.edge_index.shape[0] == 2 - ), f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}" + self.assertEqual( + data.edge_index.shape[0], + 2, + msg=f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}", + ) + + self.assertEqual( + data.edge_index.shape[1], + data.edge_attr.shape[0], + msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})", + ) - assert ( - data.edge_index.shape[1] == data.edge_attr.shape[0] - ), f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})" + self.assertEqual( + len(set(data.edge_index[0].tolist())), + data.x.shape[0], + msg=f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})", + ) - assert ( - len(set(data.edge_index[0].tolist())) == data.x.shape[0] - ), f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})" + # Check for duplicates by checking if the rows are the same (direction matters) + _, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True) + self.assertFalse( + torch.any(counts > 1), + msg="There are duplicates of directed edge in edge_index", + ) expected_data: GeomData = self.molecule_graph.get_aspirin_graph() self.assertTrue( From 945ef7c34e77bd1ebddf23c76c4fb0e6997b9ce0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:20:53 
+0200 Subject: [PATCH 077/224] restore import --- chebai_graph/preprocessing/reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chebai_graph/preprocessing/reader.py b/chebai_graph/preprocessing/reader.py index 6cd4ecd..687b199 100644 --- a/chebai_graph/preprocessing/reader.py +++ b/chebai_graph/preprocessing/reader.py @@ -8,6 +8,7 @@ import torch from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData +from torch_geometric.utils import from_networkx from chebai_graph.preprocessing import properties from chebai_graph.preprocessing.collate import GraphCollator From 53a240ad1b478bea92efa3cecce7ee2d382d74ea Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 14 May 2025 23:24:14 +0200 Subject: [PATCH 078/224] concat edge attr for undirected graph --- chebai_graph/preprocessing/datasets/chebi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 843ba35..f36fb64 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -168,7 +168,7 @@ def _merge_props_into_base(self, row): return GeomData( x=x, edge_index=geom_data.edge_index, - edge_attr=edge_attr, + edge_attr=torch.cat([edge_attr, edge_attr], dim=0), molecule_attr=molecule_attr, ) From b1f2da373de060702ee67b26b1e97c6049ff70cf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 15 May 2025 11:56:35 +0200 Subject: [PATCH 079/224] concat prop values instead of edge_attr --- chebai_graph/preprocessing/datasets/chebi.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index f36fb64..da5445e 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -162,13 +162,17 @@ def _merge_props_into_base(self, row): if 
isinstance(property, AtomProperty): x = torch.cat([x, property_values], dim=1) elif isinstance(property, BondProperty): - edge_attr = torch.cat([edge_attr, property_values], dim=1) + # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([property_values, property_values], dim=0)], + dim=1, + ) else: molecule_attr = torch.cat([molecule_attr, property_values], dim=1) return GeomData( x=x, edge_index=geom_data.edge_index, - edge_attr=torch.cat([edge_attr, edge_attr], dim=0), + edge_attr=edge_attr, molecule_attr=molecule_attr, ) From bb133bc221edbb66c76c1b08c9e97b95627be60b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 21 May 2025 13:02:20 +0200 Subject: [PATCH 080/224] edge_index: first src to tgt edges, then tgt to src edges --- .../preprocessing/reader/augmented_reader.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 334dad5..4a8994f 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -183,7 +183,7 @@ def _read_data(self, smiles: str) -> GeomData | None: # Empty features initialized; node and edge features can be added later x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0)) - edge_attr = torch.zeros((augmented_molecule["edges"]["num_edges"] * 2, 0)) + edge_attr = torch.zeros((augmented_molecule["edges"]["num_edges"], 0)) assert ( edge_index.shape[0] == 2 @@ -257,7 +257,7 @@ def _augment_graph_structure( ) # Merge all edge types - full_edge_index = torch.cat( + directed_edge_index = torch.cat( [ atom_edge_index, torch.tensor(fg_atom_edge_index, dtype=torch.long), @@ -266,6 +266,11 @@ def _augment_graph_structure( ], dim=1, ) + # First all directed edges from source to target are placed, then all 
directed edges from target to source + # are placed --- this is needed as it is easier to align the property values in same way + undirected_edge_index = torch.cat( + [directed_edge_index, directed_edge_index[[1, 0], :]], dim=1 + ) node_info = { "atom_nodes": mol, @@ -278,10 +283,10 @@ def _augment_graph_structure( ATOM_FG_EDGE: atom_fg_edges, WITHIN_FG_EDGE: internal_fg_edges, FG_GRAPHNODE_EDGE: fg_to_graph_edges, - "num_edges": self._num_of_edges, + "num_edges": self._num_of_edges * 2, # Undirected edges } - return full_edge_index, node_info, edge_info + return undirected_edge_index, node_info, edge_info @staticmethod def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: @@ -310,8 +315,8 @@ def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: # We need to ensure that directed edges which form a undirected edge are adjacent to each other edge_index_list = [[], []] for bond in mol.GetBonds(): - edge_index_list[0].extend([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) - edge_index_list[1].extend([bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]) + edge_index_list[0].append(bond.GetBeginAtomIdx()) + edge_index_list[1].append(bond.GetEndAtomIdx()) return torch.tensor(edge_index_list, dtype=torch.long) def _construct_fg_to_atom_structure( @@ -346,8 +351,8 @@ def _construct_fg_to_atom_structure( # Build edge index for fg to atom nodes connections for atom_idx in fg_group["atom"]: - fg_atom_edge_index[0].extend([self._num_of_nodes, atom_idx]) - fg_atom_edge_index[1].extend([atom_idx, self._num_of_nodes]) + fg_atom_edge_index[0].append(self._num_of_nodes) + fg_atom_edge_index[1].append(atom_idx) atom_fg_edges[f"{self._num_of_nodes}_{atom_idx}"] = { EDGE_LEVEL: ATOM_FG_EDGE } @@ -433,8 +438,8 @@ def _construct_fg_level_structure( source_fg is not None and target_fg is not None ), "Each bond should have a fg node on both end" - internal_edge_index[0].extend([source_fg, target_fg]) - internal_edge_index[1].extend([target_fg, source_fg]) + 
internal_edge_index[0].append(source_fg) + internal_edge_index[1].append(target_fg) internal_fg_edges[f"{source_fg}_{target_fg}"] = {EDGE_LEVEL: WITHIN_FG_EDGE} self._num_of_edges += 1 @@ -458,8 +463,8 @@ def _construct_fg_to_graph_node_structure( graph_edge_index = [[], []] for fg_id in structured_fg_map: - graph_edge_index[0].extend([self._num_of_nodes, fg_id]) - graph_edge_index[1].extend([fg_id, self._num_of_nodes]) + graph_edge_index[0].append(self._num_of_nodes) + graph_edge_index[1].append(fg_id) fg_graph_edges[f"{self._num_of_nodes}_{fg_id}"] = { EDGE_LEVEL: FG_GRAPHNODE_EDGE } From 53ca4387a6292b651142040ffddb46d85b0cb25b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 23 May 2025 19:21:48 +0200 Subject: [PATCH 081/224] inherit from ChebiOverX instead of Base data module, as `load_processed_data_from_file` method used in this class is available in Dynamic dataset class --- chebai_graph/preprocessing/datasets/chebi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index b132856..8e6599f 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,14 +1,15 @@ import importlib import os +from abc import ABC from typing import Callable, List, Optional import pandas as pd import torch import tqdm -from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, + ChEBIOverX, ChEBIOverXPartial, ) from lightning_utilities.core.rank_zero import rank_zero_info @@ -48,7 +49,7 @@ def _resolve_property( return getattr(graph_properties, property)() -class GraphPropertiesMixIn(XYBaseDataModule): +class GraphPropertiesMixIn(ChEBIOverX, ABC): READER = GraphPropertyReader def __init__( From 7b0b9ff4b58237f5c74c74a88f2b954c8a47fa30 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 23 May 2025 19:32:01 +0200 Subject: [PATCH 
082/224] add augmented props to config --- configs/data/chebi50_augmented_gnn.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/configs/data/chebi50_augmented_gnn.yml b/configs/data/chebi50_augmented_gnn.yml index 5cb9c38..8eee34b 100644 --- a/configs/data/chebi50_augmented_gnn.yml +++ b/configs/data/chebi50_augmented_gnn.yml @@ -1,4 +1,9 @@ class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphFGAugmentorReader init_args: properties: + # Atom properties - chebai_graph.preprocessing.properties.AtomFunctionalGroup + - chebai_graph.preprocessing.properties.AtomNodeLevel + - chebai_graph.preprocessing.properties.AtomRingSize + # Bond properties + - chebai_graph.preprocessing.properties.BondLevel From 1d5698b088e2466c3cefe21c6492a11c42438f63 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 24 May 2025 21:06:36 +0200 Subject: [PATCH 083/224] gat wrapper --- chebai_graph/models/_gat.py | 43 +++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 chebai_graph/models/_gat.py diff --git a/chebai_graph/models/_gat.py b/chebai_graph/models/_gat.py new file mode 100644 index 0000000..54d2eca --- /dev/null +++ b/chebai_graph/models/_gat.py @@ -0,0 +1,43 @@ +from torch_geometric.data import Data as GraphData +from torch_geometric.nn.models import GAT + +from .graph import GraphBaseNet + + +class GATModelWrapper(GraphBaseNet): + def __init__(self, config: dict, **kwargs): + super().__init__(**kwargs) + + self._in_length = config["in_length"] + self._hidden_length = config["hidden_length"] + self._dropout_rate = config["dropout_rate"] + self._n_conv_layers = ( + config["n_conv_layers"] if "n_conv_layers" in config else 3 + ) + self._n_linear_layers = ( + config["n_linear_layers"] if "n_linear_layers" in config else 3 + ) + self._n_atom_properties = int(config["n_atom_properties"]) + self._n_bond_properties = ( + int(config["n_bond_properties"]) if "n_bond_properties" in config else 7 + ) + self._n_molecule_properties = ( + 
int(config["n_molecule_properties"]) + if "n_molecule_properties" in config + else 0 + ) + self._gat = GAT( + in_channels=self._in_length, + hidden_channels=self._hidden_length, + num_layers=self._n_conv_layers, + dropout=self._dropout_rate, + **kwargs, + ) + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + x = graph_data.x.float() + return self._gat.forward( + x=x, edge_index=graph_data.edge_index, edge_attr=graph_data.edge_attr + ) From 0e6c4811e85e5b3672d87a636e616e573e291775 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 24 May 2025 21:34:57 +0200 Subject: [PATCH 084/224] add linear layers to gat --- chebai_graph/models/_gat.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/chebai_graph/models/_gat.py b/chebai_graph/models/_gat.py index 54d2eca..252e2ed 100644 --- a/chebai_graph/models/_gat.py +++ b/chebai_graph/models/_gat.py @@ -1,5 +1,7 @@ +import torch from torch_geometric.data import Data as GraphData from torch_geometric.nn.models import GAT +from torch_scatter import scatter_add from .graph import GraphBaseNet @@ -34,10 +36,28 @@ def __init__(self, config: dict, **kwargs): **kwargs, ) + self.linear_layers = torch.nn.ModuleList( + [ + torch.nn.Linear( + self.gnn.hidden_length + (i == 0) * self.gnn.n_molecule_properties, + self.gnn.hidden_length, + ) + for i in range(self._n_linear_layers - 1) + ] + ) + self.final_layer = torch.nn.Linear(self._hidden_length, self.out_dim) + def forward(self, batch): graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) x = graph_data.x.float() - return self._gat.forward( - x=x, edge_index=graph_data.edge_index, edge_attr=graph_data.edge_attr + a = self._gat.forward( + x=x, edge_index=graph_data.edge_index.long(), edge_attr=graph_data.edge_attr ) + a = scatter_add(a, graph_data.batch, dim=0) + + a = torch.cat([a, graph_data.molecule_attr], dim=1) + + for lin in self.linear_layers: + a = 
self.gnn.activation(lin(a)) + a = self.final_layer(a) From 98fa827da011c7d8eda1d8ad9e5d2cbd26c9a611 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 24 May 2025 23:19:56 +0200 Subject: [PATCH 085/224] gat ffn fix --- chebai_graph/models/_gat.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/chebai_graph/models/_gat.py b/chebai_graph/models/_gat.py index 252e2ed..d8eec53 100644 --- a/chebai_graph/models/_gat.py +++ b/chebai_graph/models/_gat.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F from torch_geometric.data import Data as GraphData from torch_geometric.nn.models import GAT from torch_scatter import scatter_add @@ -36,11 +37,14 @@ def __init__(self, config: dict, **kwargs): **kwargs, ) + self._ffn_activation = F.elu + self.linear_layers = torch.nn.ModuleList( [ torch.nn.Linear( - self.gnn.hidden_length + (i == 0) * self.gnn.n_molecule_properties, - self.gnn.hidden_length, + self._hidden_length + + (self._n_molecule_properties if i == 0 else 0), + self._hidden_length, ) for i in range(self._n_linear_layers - 1) ] @@ -59,5 +63,5 @@ def forward(self, batch): a = torch.cat([a, graph_data.molecule_attr], dim=1) for lin in self.linear_layers: - a = self.gnn.activation(lin(a)) - a = self.final_layer(a) + a = self._ffn_activation(lin(a)) + return self.final_layer(a) From 4319e4731a76bc800c009354601825f0dbedd340 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 May 2025 11:24:40 +0200 Subject: [PATCH 086/224] `nan_to_num` numpy2.x compatibility fix for https://github.com/ChEB-AI/python-chebai-graph/issues/10 --- chebai_graph/preprocessing/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/properties.py b/chebai_graph/preprocessing/properties.py index 29808cd..2b3acf8 100644 --- a/chebai_graph/preprocessing/properties.py +++ b/chebai_graph/preprocessing/properties.py @@ -155,5 +155,5 @@ def get_property_value(self, mol: Chem.rdchem.Mol): features_normalized 
= generator_normalized.processMol( mol, Chem.MolToSmiles(mol) ) - np.nan_to_num(features_normalized, copy=False) + features_normalized = np.nan_to_num(features_normalized) return [features_normalized[1:]] From 7c0a484edb86a8d9297360350e913afe5d3c0ca9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 May 2025 18:59:03 +0200 Subject: [PATCH 087/224] add print statements --- chebai_graph/preprocessing/datasets/chebi.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 8e6599f..9bddad7 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -107,10 +107,12 @@ def _setup_properties(self): if not os.path.isfile(self.get_property_path(property)): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode + rank_zero_info(f"\tReading property valeus...") property_values = [ self.reader.read_property(feat, property) for feat in tqdm.tqdm(features) ] + rank_zero_info(f"\tEncoding property values...") property.encoder.on_start(property_values=property_values) encoded_values = [ enc_if_not_none(property.encoder.encode, value) @@ -151,6 +153,7 @@ def _merge_props_into_base(self, row): assert isinstance(geom_data, GeomData) for property in self.properties: property_values = row[f"{property.name}"] + rank_zero_info(f"Merging {property.name} into base dataframe...") if isinstance(property_values, torch.Tensor): if len(property_values.size()) == 0: property_values = property_values.unsqueeze(0) From ffc0b7529671561944635bbc9b95bccffcc3cb22 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 25 May 2025 19:18:55 +0200 Subject: [PATCH 088/224] remove print for dataloader phase --- chebai_graph/preprocessing/datasets/chebi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 
9bddad7..2721b1c 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -153,7 +153,6 @@ def _merge_props_into_base(self, row): assert isinstance(geom_data, GeomData) for property in self.properties: property_values = row[f"{property.name}"] - rank_zero_info(f"Merging {property.name} into base dataframe...") if isinstance(property_values, torch.Tensor): if len(property_values.size()) == 0: property_values = property_values.unsqueeze(0) From f2cafbf082f8fb4275bcbc47d2529a6e5765c509 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 26 May 2025 13:45:05 +0200 Subject: [PATCH 089/224] gat model pop config if used to avoid unnecessary kwargs getting passed to GAT --- chebai_graph/models/_gat.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/chebai_graph/models/_gat.py b/chebai_graph/models/_gat.py index d8eec53..833477b 100644 --- a/chebai_graph/models/_gat.py +++ b/chebai_graph/models/_gat.py @@ -8,33 +8,25 @@ class GATModelWrapper(GraphBaseNet): + NAME = "GATModel" + def __init__(self, config: dict, **kwargs): super().__init__(**kwargs) - self._in_length = config["in_length"] - self._hidden_length = config["hidden_length"] - self._dropout_rate = config["dropout_rate"] - self._n_conv_layers = ( - config["n_conv_layers"] if "n_conv_layers" in config else 3 - ) - self._n_linear_layers = ( - config["n_linear_layers"] if "n_linear_layers" in config else 3 - ) - self._n_atom_properties = int(config["n_atom_properties"]) - self._n_bond_properties = ( - int(config["n_bond_properties"]) if "n_bond_properties" in config else 7 - ) - self._n_molecule_properties = ( - int(config["n_molecule_properties"]) - if "n_molecule_properties" in config - else 0 - ) + self._in_length = int(config.pop("in_length")) + self._hidden_length = int(config.pop("hidden_length")) + self._dropout_rate = float(config.pop("dropout_rate", 0.1)) + self._n_conv_layers = 
int(config.pop("n_conv_layers", 3)) + self._n_linear_layers = int(config.pop("n_linear_layers", 3)) + self._n_atom_properties = int(config.pop("n_atom_properties", 0)) + self._n_bond_properties = int(config.pop("n_bond_properties", 0)) + self._n_molecule_properties = int(config.pop("n_molecule_properties", 0)) self._gat = GAT( in_channels=self._in_length, hidden_channels=self._hidden_length, num_layers=self._n_conv_layers, dropout=self._dropout_rate, - **kwargs, + **config, ) self._ffn_activation = F.elu From 39fadacc9f997cb3ea3b871ed4ae5cb9f884697c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 26 May 2025 13:45:15 +0200 Subject: [PATCH 090/224] gat config --- configs/model/gat.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 configs/model/gat.yml diff --git a/configs/model/gat.yml b/configs/model/gat.yml new file mode 100644 index 0000000..3e980db --- /dev/null +++ b/configs/model/gat.yml @@ -0,0 +1,15 @@ +class_path: chebai_graph.models.GATModelWrapper +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_length: 256 + hidden_length: 512 + dropout_rate: 0.1 + n_conv_layers: 3 + heads: 5 # Default is one + # v2: True # -- to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv + n_linear_layers: 3 + n_atom_properties: 158 + n_bond_properties: 7 + n_molecule_properties: 200 From abe65b45ae4e558720b0e6a364ccf4d1e5a83d6c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 26 May 2025 13:52:22 +0200 Subject: [PATCH 091/224] add GAT wrapper to init models --- chebai_graph/models/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index e69de29..2a46f65 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -0,0 +1,3 @@ +from ._gat import GATModelWrapper + +__all__ = ["GATModelWrapper"] From dc4d3a7f02edeabe07154c94f0884295aab03f85 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 26 May 2025 
14:38:01 +0200 Subject: [PATCH 092/224] output (or hidden) channels should be divisible by the number of heads --- configs/model/gat.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/model/gat.yml b/configs/model/gat.yml index 3e980db..8ce3ad3 100644 --- a/configs/model/gat.yml +++ b/configs/model/gat.yml @@ -7,7 +7,7 @@ init_args: hidden_length: 512 dropout_rate: 0.1 n_conv_layers: 3 - heads: 5 # Default is one + heads: 4 # the number of heads should be divisible by output channels (hidden channels if output channel not given) # v2: True # -- to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv n_linear_layers: 3 n_atom_properties: 158 From 66ea24832d8e6977f93495941707b7a1ec98a920 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 11:06:17 +0200 Subject: [PATCH 093/224] why in_length is needed ? if n_atom_properties is available --- chebai_graph/models/_gat.py | 4 ++-- configs/model/gat.yml | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/chebai_graph/models/_gat.py b/chebai_graph/models/_gat.py index 833477b..da61235 100644 --- a/chebai_graph/models/_gat.py +++ b/chebai_graph/models/_gat.py @@ -13,7 +13,6 @@ class GATModelWrapper(GraphBaseNet): def __init__(self, config: dict, **kwargs): super().__init__(**kwargs) - self._in_length = int(config.pop("in_length")) self._hidden_length = int(config.pop("hidden_length")) self._dropout_rate = float(config.pop("dropout_rate", 0.1)) self._n_conv_layers = int(config.pop("n_conv_layers", 3)) @@ -22,10 +21,11 @@ def __init__(self, config: dict, **kwargs): self._n_bond_properties = int(config.pop("n_bond_properties", 0)) self._n_molecule_properties = int(config.pop("n_molecule_properties", 0)) self._gat = GAT( - in_channels=self._in_length, + in_channels=self._n_atom_properties, hidden_channels=self._hidden_length, num_layers=self._n_conv_layers, dropout=self._dropout_rate, + edge_dim=self._n_bond_properties, **config, ) diff --git
a/configs/model/gat.yml b/configs/model/gat.yml index 8ce3ad3..0cc9b24 100644 --- a/configs/model/gat.yml +++ b/configs/model/gat.yml @@ -3,11 +3,10 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - in_length: 256 hidden_length: 512 dropout_rate: 0.1 n_conv_layers: 3 - heads: 4 # the number of heads should be divisible by output channels (hidden channels if output channel not given) + heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) # v2: True # -- to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv n_linear_layers: 3 n_atom_properties: 158 From 6bc4f35bd70e16093ab5af2019d2ee2ce570f508 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 16:07:11 +0200 Subject: [PATCH 094/224] fg level edge fix If two atoms of a FG points to atom(s) belonging to another FG. In this case, only one edge is counted. --- chebai_graph/preprocessing/reader/augmented_reader.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 4a8994f..be79c73 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -440,8 +440,12 @@ def _construct_fg_level_structure( internal_edge_index[0].append(source_fg) internal_edge_index[1].append(target_fg) - internal_fg_edges[f"{source_fg}_{target_fg}"] = {EDGE_LEVEL: WITHIN_FG_EDGE} - self._num_of_edges += 1 + edge_str = f"{source_fg}_{target_fg}" + if edge_str not in internal_fg_edges: + # If two atoms of a FG points to atom(s) belonging to another FG. In this case, only one edge is counted. + # Eg. 
In CHEBI:52723, atom idx 13 and 16 of a FG points to atom idx 18 of another FG + internal_fg_edges[edge_str] = {EDGE_LEVEL: WITHIN_FG_EDGE} + self._num_of_edges += 1 return internal_edge_index, internal_fg_edges From 4fb00e66edf9201b1e2f5243f93d69ed4f535d06 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 16:18:08 +0200 Subject: [PATCH 095/224] assert statements for num of edges and nodes --- .../preprocessing/reader/augmented_reader.py | 61 +++++++++++-------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index be79c73..ed31432 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -35,12 +35,10 @@ def __init__(self, *args, **kwargs): **kwargs: Additional keyword arguments passed to the ChemDataReader. """ super().__init__(*args, **kwargs) - self.f_cnt_for_smiles = ( - 0 # Record number of failures when constructing molecule from smiles - ) - self.f_cnt_for_aug_graph = ( - 0 # Record number of failure during augmented graph construction - ) + # Record number of failures when constructing molecule from smiles + self.f_cnt_for_smiles = 0 + # Record number of failure during augmented graph construction + self.f_cnt_for_aug_graph = 0 self.mol_object_buffer = {} self._num_of_nodes = 0 self._num_of_edges = 0 @@ -245,15 +243,15 @@ def _augment_graph_structure( if returned_result is None: return None - fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds = ( + fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds = ( returned_result ) fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure( - structured_fg_map, bonds + fg_to_atoms_map, bonds ) fg_graph_edge_index, graph_node, fg_to_graph_edges = ( - self._construct_fg_to_graph_node_structure(structured_fg_map) + 
self._construct_fg_to_graph_node_structure(fg_to_atoms_map) ) # Merge all edge types @@ -272,20 +270,35 @@ def _augment_graph_structure( [directed_edge_index, directed_edge_index[[1, 0], :]], dim=1 ) + total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes), 1]) + assert ( + self._num_of_nodes == total_atoms + ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._num_of_nodes}" node_info = { "atom_nodes": mol, "fg_nodes": fg_nodes, "graph_node": graph_node, "num_nodes": self._num_of_nodes, } + + total_edges = sum( + [ + mol.GetNumBonds(), + len(atom_fg_edges), + len(internal_fg_edges), + len(fg_to_graph_edges), + ] + ) + assert ( + self._num_of_edges == total_edges + ), f"Mismatch in number of edges: expected {total_edges}, got {self._num_of_edges}" edge_info = { WITHIN_ATOMS_EDGE: mol, ATOM_FG_EDGE: atom_fg_edges, WITHIN_FG_EDGE: internal_fg_edges, FG_GRAPHNODE_EDGE: fg_to_graph_edges, - "num_edges": self._num_of_edges * 2, # Undirected edges + "num_undirected_edges": self._num_of_edges * 2, # Undirected edges } - return undirected_edge_index, node_info, edge_info @staticmethod @@ -342,12 +355,11 @@ def _construct_fg_to_atom_structure( fg_atom_edge_index = [[], []] fg_nodes, atom_fg_edges = {}, {} - structured_fg_map = ( - {} - ) # Contains augmented fg-nodes and connected atoms indices + # Contains augmented fg-nodes and connected atoms indices + fg_to_atoms_map = {} for fg_group in structure.values(): - structured_fg_map[self._num_of_nodes] = {"atom": fg_group["atom"]} + fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} # Build edge index for fg to atom nodes connections for atom_idx in fg_group["atom"]: @@ -370,9 +382,8 @@ def _construct_fg_to_atom_structure( "A functional group must not span multiple ring sizes." 
) - if ( - len(ring_fg) == 1 - ): # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings + if len(ring_fg) == 1: + # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings ring_size = next(iter(ring_fg)) fg_nodes[self._num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, @@ -406,16 +417,16 @@ def _construct_fg_to_atom_structure( self._num_of_nodes += 1 - return fg_atom_edge_index, fg_nodes, atom_fg_edges, structured_fg_map, bonds + return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds def _construct_fg_level_structure( - self, structured_fg_map: dict, bonds: list + self, fg_to_atoms_map: dict, bonds: list ) -> Tuple[List[List[int]], dict]: """ Constructs internal edges between functional group nodes based on bond connections. Args: - structured_fg_map (dict): Mapping from FG ID to atom indices. + fg_to_atoms_map (dict): Mapping from FG ID to atom indices. bonds (list): List of bond tuples (source, target, ...). Returns: @@ -428,7 +439,7 @@ def _construct_fg_level_structure( source_atom, target_atom = bond[:2] source_fg, target_fg = None, None - for fg_id, data in structured_fg_map.items(): + for fg_id, data in fg_to_atoms_map.items(): if source_atom in data["atom"]: source_fg = fg_id if target_atom in data["atom"]: @@ -450,13 +461,13 @@ def _construct_fg_level_structure( return internal_edge_index, internal_fg_edges def _construct_fg_to_graph_node_structure( - self, structured_fg_map: dict + self, fg_to_atoms_map: dict ) -> Tuple[List[List[int]], dict, dict]: """ Constructs edges between functional group nodes and a global graph-level node. Args: - structured_fg_map (dict): Mapping from FG ID to atom indices. + fg_to_atoms_map (dict): Mapping from FG ID to atom indices. Returns: Tuple[List[List[int]], dict, dict]: Edge index, graph-level node, edge attributes. 
@@ -466,7 +477,7 @@ def _construct_fg_to_graph_node_structure( fg_graph_edges = {} graph_edge_index = [[], []] - for fg_id in structured_fg_map: + for fg_id in fg_to_atoms_map: graph_edge_index[0].append(self._num_of_nodes) graph_edge_index[1].append(fg_id) fg_graph_edges[f"{self._num_of_nodes}_{fg_id}"] = { From 65dfda649743cc4f6b5631c0efdffe1e610b42e8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 16:23:59 +0200 Subject: [PATCH 096/224] bond prop fix --- .../properties/augmented_properties.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index ebbf455..6f37687 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -47,12 +47,16 @@ def get_property_value(self, augmented_mol: Dict) -> List: # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges]) - prop_list.extend([self.get_bond_value(bond) for bond in fg_edges]) - prop_list.extend([self.get_bond_value(bond) for bond in fg_graph_node_edges]) + prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges.values()]) + prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) + prop_list.extend( + [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] + ) + + num_directed_edges = augmented_mol[self.MAIN_KEY]["num_undirected_edges"] // 2 assert ( - len(prop_list) == augmented_mol[self.MAIN_KEY]["num_edges"] - ), "Number of property values should be equal to number of edges" + len(prop_list) == num_directed_edges + ), f"Number of property values 
({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " return prop_list From 200d2cd4a6963f4c381a925cb64153030fc8e2fb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 28 May 2025 19:44:36 +0200 Subject: [PATCH 097/224] num_undir_edges key fix --- .../preprocessing/properties/augmented_properties.py | 2 +- chebai_graph/preprocessing/properties/constants.py | 1 + chebai_graph/preprocessing/reader/augmented_reader.py | 8 ++++---- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 6f37687..49aa319 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -53,7 +53,7 @@ def get_property_value(self, augmented_mol: Dict) -> List: [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] ) - num_directed_edges = augmented_mol[self.MAIN_KEY]["num_undirected_edges"] // 2 + num_directed_edges = augmented_mol[self.MAIN_KEY][NUM_EDGES] // 2 assert ( len(prop_list) == num_directed_edges ), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. 
must be equal to {num_directed_edges} " diff --git a/chebai_graph/preprocessing/properties/constants.py b/chebai_graph/preprocessing/properties/constants.py index 67de13a..f64e5cb 100644 --- a/chebai_graph/preprocessing/properties/constants.py +++ b/chebai_graph/preprocessing/properties/constants.py @@ -10,3 +10,4 @@ ATOM_FG_EDGE = "atom_fg_lvl" FG_GRAPHNODE_EDGE = "fg_graphNode_lvl" EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_EDGE} +NUM_EDGES = "num_undirected_edges" diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index ed31432..0628bcc 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -181,7 +181,7 @@ def _read_data(self, smiles: str) -> GeomData | None: # Empty features initialized; node and edge features can be added later x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0)) - edge_attr = torch.zeros((augmented_molecule["edges"]["num_edges"], 0)) + edge_attr = torch.zeros((augmented_molecule["edges"][NUM_EDGES], 0)) assert ( edge_index.shape[0] == 2 @@ -297,7 +297,7 @@ def _augment_graph_structure( ATOM_FG_EDGE: atom_fg_edges, WITHIN_FG_EDGE: internal_fg_edges, FG_GRAPHNODE_EDGE: fg_to_graph_edges, - "num_undirected_edges": self._num_of_edges * 2, # Undirected edges + NUM_EDGES: self._num_of_edges * 2, # Undirected edges } return undirected_edge_index, node_info, edge_info @@ -449,12 +449,12 @@ def _construct_fg_level_structure( source_fg is not None and target_fg is not None ), "Each bond should have a fg node on both end" - internal_edge_index[0].append(source_fg) - internal_edge_index[1].append(target_fg) edge_str = f"{source_fg}_{target_fg}" if edge_str not in internal_fg_edges: # If two atoms of a FG points to atom(s) belonging to another FG. In this case, only one edge is counted. # Eg. 
In CHEBI:52723, atom idx 13 and 16 of a FG points to atom idx 18 of another FG + internal_edge_index[0].append(source_fg) + internal_edge_index[1].append(target_fg) internal_fg_edges[edge_str] = {EDGE_LEVEL: WITHIN_FG_EDGE} self._num_of_edges += 1 From 1a80678d0a5b9955fa99dcda3d605199b07a8e97 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 5 Jun 2025 15:23:58 +0200 Subject: [PATCH 098/224] GraphPropertyReader should inherit from DataReader instead of ChemDataReader, - As no methods from ChemDataReader are reused - Also it allows avoiding unnecessary creation of reader_name/tokens.txt file - As these token files are created by the encoder in graph repo --- chebai_graph/preprocessing/reader/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index fafe73c..15a131d 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -14,7 +14,7 @@ from chebai_graph.preprocessing.properties import MolecularProperty -class GraphPropertyReader(dr.ChemDataReader): +class GraphPropertyReader(dr.DataReader): COLLATOR = GraphCollator def __init__( From 5dbb516bf7585083973654f111d77cb54ec08960 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 6 Jun 2025 00:40:18 +0200 Subject: [PATCH 099/224] for ring, assigned connected atom ring_size as fg --- chebai_graph/preprocessing/reader/augmented_reader.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 0628bcc..bd6c0e0 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -391,16 +391,22 @@ def _construct_fg_to_atom_structure( "FG": f"RING_{ring_size}", "RING": ring_size, } - else: # No connected has a ring size which indicates it is simple FG + # In this case, all
atoms of Ring/Fused Ring are assigned the ring size as functional group + for atom_idx in fg_group["atom"]: + mol.GetAtomWithIdx(atom_idx).SetProp("FG", f"RING_{ring_size}") + + else: # No connected atoms have a ring size which indicates it is simple FG fg_set = {mol.GetAtomWithIdx(i).GetProp("FG") for i in fg_group["atom"]} if "" in fg_set and len(fg_set) == 1: + # TODO: Check how GraphReader handles the wildcard smiles case # There will be no FGs for wildcard SMILES Eg. CHEBI:33429 return None if "" in fg_set or len(fg_set) > 1: raise ValueError("Invalid functional group assignment to atoms.") + # Select any one connected atom to get FG type and ring size for atom_idx in fg_group["atom"]: atom = mol.GetAtomWithIdx(atom_idx) if atom.GetProp("FG"): From dfb6b9b04a94af00a99ac0b97d6c39faeb97f5bb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 6 Jun 2025 15:50:35 +0200 Subject: [PATCH 100/224] make fg graph code more efficient --- .../preprocessing/reader/augmented_reader.py | 45 +++++++++++-------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index bd6c0e0..84e6545 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -361,6 +361,8 @@ def _construct_fg_to_atom_structure( for fg_group in structure.values(): fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} + ring_fg = set() + connected_atoms = [] # Build edge index for fg to atom nodes connections for atom_idx in fg_group["atom"]: fg_atom_edge_index[0].append(self._num_of_nodes) @@ -370,12 +372,13 @@ def _construct_fg_to_atom_structure( } self._num_of_edges += 1 - # Identify ring vs. 
functional group type - ring_fg = { - mol.GetAtomWithIdx(i).GetProp("RING") - for i in fg_group["atom"] - if mol.GetAtomWithIdx(i).GetProp("RING") - } + atom = mol.GetAtomWithIdx( + atom_idx + ) # reference to atom in mol is returned + connected_atoms.append(atom) + + if atom.GetProp("RING"): + ring_fg.add(atom.GetProp("RING")) if len(ring_fg) > 1: raise ValueError( @@ -392,11 +395,15 @@ def _construct_fg_to_atom_structure( "RING": ring_size, } # In this case, all atoms of Ring/Fused Ring are assigned the ring size as functional group - for atom_idx in fg_group["atom"]: - mol.GetAtomWithIdx(atom_idx).SetProp("FG", f"RING_{ring_size}") + for atom in connected_atoms: + atom.SetProp("FG", f"RING_{ring_size}") else: # No connected atoms have a ring size which indicates it is simple FG - fg_set = {mol.GetAtomWithIdx(i).GetProp("FG") for i in fg_group["atom"]} + fg_set = {atom.GetProp("FG") for atom in connected_atoms} + if not fg_set: + raise ValueError( + "No functional group assigned to atoms in the functional group." + ) if "" in fg_set and len(fg_set) == 1: # TODO: Check how GraphReader handles the wildcard smiles case @@ -407,20 +414,20 @@ def _construct_fg_to_atom_structure( raise ValueError("Invalid functional group assignment to atoms.") # Select any one connected atom to get FG type and ring size - for atom_idx in fg_group["atom"]: - atom = mol.GetAtomWithIdx(atom_idx) - if atom.GetProp("FG"): - fg_nodes[self._num_of_nodes] = { - NODE_LEVEL: FG_NODE_LEVEL, - "FG": atom.GetProp("FG"), - "RING": atom.GetProp("RING"), - } - break - else: + representative_atom = next( + (atom for atom in connected_atoms if atom.GetProp("FG")), None + ) + if representative_atom is None: raise AssertionError( "Expected at least one atom with a functional group." 
) + fg_nodes[self._num_of_nodes] = { + NODE_LEVEL: FG_NODE_LEVEL, + "FG": representative_atom.GetProp("FG"), + "RING": representative_atom.GetProp("RING"), + } + self._num_of_nodes += 1 return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds From c0739ada58558b9c00da11dd54e0bc74e4214637 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 7 Jun 2025 13:03:45 +0200 Subject: [PATCH 101/224] add plot option to plot only rdkit molecule using rdkit --- .../utils/visualize_augmented_molecule.py | 82 +++++++++++++++++-- 1 file changed, 77 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 8224d17..aa0b62a 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -1,8 +1,12 @@ +import io + import matplotlib import matplotlib.pyplot as plt import networkx as nx from jsonargparse import CLI -from rdkit.Chem import AllChem, Mol +from PIL import Image +from rdkit.Chem import BondType, Mol, rdDepictor +from rdkit.Chem.Draw import rdMolDraw2D from torch import Tensor from chebai_graph.preprocessing.properties.constants import * @@ -24,6 +28,17 @@ } +BOND_COLOR_MAP = { + BondType.SINGLE: "black", + BondType.DOUBLE: "blue", + BondType.TRIPLE: "green", + BondType.DATIVE: "red", + BondType.DATIVEL: "red", + BondType.DATIVER: "red", + BondType.DATIVEONE: "red", +} + + def _create_graph( edge_index: Tensor, augmented_graph_nodes: dict, augmented_graph_edges: dict ) -> nx.Graph: @@ -91,7 +106,7 @@ def _create_graph( elif undirected_edge_set & fg_graph_edges: edge_type = FG_GRAPHNODE_EDGE else: - raise Exception("Unexpected edge type") + raise ValueError("Unexpected edge type") G.add_edge(src, tgt, edge_type=edge_type, edge_color=EDGE_COLOR_MAP[edge_type]) return G @@ -202,10 +217,10 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: """ try: from 
plotly import graph_objects as go - except ImportError: + except ImportError as e: raise ImportError( "Plotly is required for 3D plotting. Install it with `pip install plotly`." - ) + ) from e # Generate 3D coordinates for atoms AllChem.EmbedMolecule(mol) @@ -322,6 +337,54 @@ def plot_augmented_graph( raise ValueError(f"Unknown plot type: {plot_type}") +def plot_nonaugment_molecule_graph(mol: Mol, size=(800, 800)) -> None: + """ + Visualize a molecule using rdkit. + """ + + print(f"Number of atoms: {mol.GetNumAtoms()}") + print(f"Number of bonds: {mol.GetNumBonds()}") + print("\nAtoms:") + for atom in mol.GetAtoms(): + print(f"\tAtom index: {atom.GetIdx()}, Symbol: {atom.GetSymbol()}") + + print("\nBonds:") + for bond in mol.GetBonds(): + a1 = bond.GetBeginAtomIdx() + a2 = bond.GetEndAtomIdx() + btype = bond.GetBondType() + print(f"\tBond index: {bond.GetIdx()}, Atoms: ({a1}, {a2}), Type: {btype}") + + # Generate 2D coordinates + rdDepictor.Compute2DCoords(mol) + + # Set up drawer with high resolution + drawer = rdMolDraw2D.MolDraw2DCairo(*size) + options = drawer.drawOptions() + + # Display atom indices and symbols + options.addAtomIndices = True + options.addStereoAnnotation = True + options.padding = 0.05 # Less whitespace + options.fixedBondLength = 25 # for visual clarity + + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + + # Convert to image + png = drawer.GetDrawingText() + img = Image.open(io.BytesIO(png)) + + # Show using matplotlib + dpi = 300 + plt.figure(figsize=(size[0] / dpi, size[1] / dpi), dpi=dpi) + plt.imshow(img) + plt.axis("off") + plt.tight_layout(pad=0) + # plt.title("RDKit 2D Molecule with Atom Indices", fontsize=14) + plt.show() + + class Main: """ Command-line wrapper class for plotting augmented molecular graphs. 
@@ -342,6 +405,13 @@ def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> No - 3d: Hierarchical 3D-graph """ mol = self._fg_reader._smiles_to_mol(smiles) # noqa + if mol is None: + raise ValueError(f"Invalid SMILES: {smiles}") + + if plot_type == "molecule_only": + plot_nonaugment_molecule_graph(mol) + return + edge_index, augmented_molecule = self._fg_reader._create_augmented_graph( mol ) # noqa @@ -355,4 +425,6 @@ def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> No # 1-hydroxy-2-naphthoic acid -> OC(=O)c1ccc2ccccc2c1O ; CHEBI:36108 ; Fused Rings # 3-nitrobenzoic acid -> OC(=O)C1=CC(=CC=C1)[N+]([O-])=O ; CHEBI:231494 ; Ring + Novel atom (Nitrogen) # nile blue A -> [Cl-].CCN(CC)c1ccc2nc3c(cc(N)c4ccccc34)[o+]c2c1 ; CHEBI:52163 ; Fused rings + Novel atoms - CLI(Main) + # CHEBI:52723; Complicated molecule; 'O.O.[Cl-].[Cl-].C1=Cc2ccc3C=CC=[N]4c3c2[N](=C1)[Ru++]4123[N]4=CC=Cc5ccc6C=CC=[N]1c6c45.C1=Cc4ccc5C=CC=[N]2c5c4[N]3=C1' + # CHEBI:87627; Complicated molecule; "[Cl-].[Cl-].[Cl-].[Cl-].[Zn++].COc1cc(ccc1[N+]#N)-c1ccc([N+]#N)c(OC)c1" + CLI(Main, as_positional=False) From e4145a0a45b323a2e5dd8251cb87a29061b1c634 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 7 Jun 2025 13:46:00 +0200 Subject: [PATCH 102/224] fix for atoms with no functional group --- .../preprocessing/reader/augmented_reader.py | 82 +++++++++++-------- 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 84e6545..490a7eb 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -52,7 +52,6 @@ def name(cls) -> str: Returns: str: Name of the augmentor. 
""" - pass @abstractmethod def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, Dict]: @@ -65,7 +64,6 @@ def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, Dict]: Returns: Tuple[torch.Tensor, Dict]: Graph edge index and augmented molecule information """ - pass @abstractmethod def _read_data(self, raw_data: str) -> GeomData: @@ -78,7 +76,6 @@ def _read_data(self, raw_data: str) -> GeomData: Returns: GeomData: `torch_geometric.data.Data` object. """ - pass def _smiles_to_mol(self, smiles: str) -> Chem.Mol: """ @@ -171,6 +168,7 @@ def _read_data(self, smiles: str) -> GeomData | None: return None returned_result = self._create_augmented_graph(mol) + # If the returned result is None, it indicates that the graph augmentation failed if returned_result is None: rank_zero_info(f"Failed to construct augmented graph for smiles {smiles}") self.f_cnt_for_aug_graph += 1 @@ -197,9 +195,7 @@ def _read_data(self, smiles: str) -> GeomData | None: return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) - def _create_augmented_graph( - self, mol: Chem.Mol - ) -> Optional[Tuple[torch.Tensor, Dict]]: + def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: """ Generates an augmented graph from a SMILES string. @@ -207,21 +203,18 @@ def _create_augmented_graph( mol (Chem.Mol): A molecule generated by RDKit. Returns: - Tuple[dict, torch.Tensor]: Augmented molecule information and edge index. + Tuple[torch.Tensor, dict]: + - Augmented graph edge index, + - Augmented graph (nodes and edges). 
""" - returned_result = self._augment_graph_structure(mol) - if returned_result is None: - return None - - edge_index, node_info, edge_info = returned_result - + edge_index, node_info, edge_info = self._augment_graph_structure(mol) augmented_molecule = {"nodes": node_info, "edges": edge_info} return edge_index, augmented_molecule def _augment_graph_structure( self, mol: Chem.Mol - ) -> Optional[Tuple[torch.Tensor, dict, dict]]: + ) -> Tuple[torch.Tensor, dict, dict]: """ Constructs the full augmented graph structure from a molecule. @@ -229,7 +222,10 @@ def _augment_graph_structure( mol (Chem.Mol): RDKit molecule object. Returns: - Tuple[torch.Tensor, dict, dict]: Edge index, node metadata, and edge metadata. + Tuple[torch.Tensor, dict, dict]: + - Augmented graph edge index, + - Augmented graph node attributes + - Augmented graph edge attributes. """ self._num_of_nodes = mol.GetNumAtoms() self._num_of_edges = mol.GetNumBonds() @@ -238,18 +234,14 @@ def _augment_graph_structure( atom_edge_index = self._generate_atom_level_edge_index(mol) # Create FG-level structure and edges - returned_result = self._construct_fg_to_atom_structure(mol) - - if returned_result is None: - return None - fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds = ( - returned_result + self._construct_fg_to_atom_structure(mol) ) fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure( fg_to_atoms_map, bonds ) + fg_graph_edge_index, graph_node, fg_to_graph_edges = ( self._construct_fg_to_graph_node_structure(fg_to_atoms_map) ) @@ -323,7 +315,7 @@ def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: mol (Chem.Mol): RDKit molecule. Returns: - torch.Tensor: Bidirectional edge index tensor. + torch.Tensor: Directed edge index tensor. 
""" # We need to ensure that directed edges which form a undirected edge are adjacent to each other edge_index_list = [[], []] @@ -334,17 +326,25 @@ def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: def _construct_fg_to_atom_structure( self, mol: Chem.Mol - ) -> Optional[Tuple[List[List[int]], dict, dict, dict, list]]: + ) -> tuple[list[list[int]], dict, dict, dict, list]: """ Constructs edges between functional group (FG) nodes and atom nodes. + This method detects functional groups in the molecule and creates edges + between FG nodes and their connected atom nodes. Args: mol (Chem.Mol): RDKit molecule. Returns: - Tuple[List[List[int]], dict, dict, dict, list]: - Edge index, FG node info, FG-atom edge attributes, - structured FG mapping, and bond list. + tuple[list[list[int]], dict, dict, dict, list]: A tuple containing: + - Edge index for FG to atom connections. + - FG node info, + - FG-atom edge attributes, + - FG to atoms mapping, + - Bonds between FG nodes. + + Raises: + ValueError: If functional groups span multiple ring sizes or if no functional group is assigned to atoms. """ # Rule-based algorithm to detect functional groups @@ -358,7 +358,7 @@ def _construct_fg_to_atom_structure( # Contains augmented fg-nodes and connected atoms indices fg_to_atoms_map = {} - for fg_group in structure.values(): + for fg_smiles, fg_group in structure.items(): fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} ring_fg = set() @@ -406,11 +406,20 @@ def _construct_fg_to_atom_structure( ) if "" in fg_set and len(fg_set) == 1: - # TODO: Check how GraphReader handles the wildcard smiles case - # There will be no FGs for wildcard SMILES Eg. CHEBI:33429 - return None - - if "" in fg_set or len(fg_set) > 1: + if len(connected_atoms) == 1: + # If there is only one atom and one edge connecting this atom to its fg_atom, + # the functional group will be the symbol of this atom + # This special case is to handle wildcard SMILES Eg. 
CHEBI:33429 + atom = connected_atoms[0] + atom.SetProp("FG", atom.GetSymbol()) + else: + # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name + # assigned, we assign the functional group as the part of SMILES which belong to the functional group + # Eg. CHEBI:55388, atom idx 2 and 3 have no functional group name, so "[C-]#[C-]" is used + for atom in connected_atoms: + atom.SetProp("FG", fg_smiles) + + if "" in fg_set and len(fg_set) > 1: raise ValueError("Invalid functional group assignment to atoms.") # Select any one connected atom to get FG type and ring size @@ -443,7 +452,9 @@ def _construct_fg_level_structure( bonds (list): List of bond tuples (source, target, ...). Returns: - Tuple[List[List[int]], dict]: Edge index and edge attribute dictionary. + Tuple[List[List[int]], dict]: + - Edge index within fg nodes + - Edge attributes for edges within fg nodes. """ internal_fg_edges = {} internal_edge_index = [[], []] @@ -483,7 +494,10 @@ def _construct_fg_to_graph_node_structure( fg_to_atoms_map (dict): Mapping from FG ID to atom indices. Returns: - Tuple[List[List[int]], dict, dict]: Edge index, graph-level node, edge attributes. 
+ Tuple[List[List[int]], dict, dict]: + - Graph to FG Edge index + - Graph-level node attribute + - FG to Graph Edge attributes """ graph_node = {NODE_LEVEL: GRAPH_NODE_LEVEL, "FG": "graph_fg", "RING": "0"} From ac8cb1e3275296f3bbb9982302c2688af63c93d1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 7 Jun 2025 16:09:42 +0200 Subject: [PATCH 103/224] add fg summary to reader --- .../preprocessing/reader/augmented_reader.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 490a7eb..fbcfa40 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,3 +1,4 @@ +import textwrap from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple @@ -143,6 +144,24 @@ class GraphFGAugmentorReader(_AugmentorReader): The FG nodes to connected to its related atoms and graph node is connected to all FG nodes. """ + def __init__(self, *args, **kwargs): + """ + Initializes the GraphFGAugmentorReader and sets up the failure counter and molecule cache. + + Args: + *args: Additional arguments passed to the parent class. + **kwargs: Additional keyword arguments passed to the parent class. 
+ """ + super().__init__(*args, **kwargs) + # Record number of functional groups using relevant part of SMILES which belong to them, as function group name + self._cnt_fg_using_smiles = 0 + # Record number molecules with atleast one functional group using relevant part of SMILES which belong to them, as function group name + self._cnt_mol_with_fg_using_smiles = 0 + # Record number of functional groups using atom symbol as functional group name + self._cnt_fg_using_atom_symbol = 0 + # Record number molecules with atleast one functional group using atom symbol as functional group name + self._cnt_mol_with_fg_using_atom_symbol = 0 + @classmethod def name(cls) -> str: """ @@ -357,6 +376,8 @@ def _construct_fg_to_atom_structure( fg_nodes, atom_fg_edges = {}, {} # Contains augmented fg-nodes and connected atoms indices fg_to_atoms_map = {} + flag_mol_has_fg_using_smiles = False + flag_mol_has_fg_using_atom_symbol = False for fg_smiles, fg_group in structure.items(): fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} @@ -412,13 +433,18 @@ def _construct_fg_to_atom_structure( # This special case is to handle wildcard SMILES Eg. CHEBI:33429 atom = connected_atoms[0] atom.SetProp("FG", atom.GetSymbol()) + self._cnt_fg_using_atom_symbol += 1 + flag_mol_has_fg_using_atom_symbol = True else: # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name - # assigned, we assign the functional group as the part of SMILES which belong to the functional group + # assigned, we assign the functional group as the relevant part of SMILES which belong to the functional group # Eg. 
CHEBI:55388, atom idx 2 and 3 have no functional group name, so "[C-]#[C-]" is used for atom in connected_atoms: atom.SetProp("FG", fg_smiles) + self._cnt_fg_using_smiles += 1 + flag_mol_has_fg_using_smiles = True + if "" in fg_set and len(fg_set) > 1: raise ValueError("Invalid functional group assignment to atoms.") @@ -439,6 +465,11 @@ def _construct_fg_to_atom_structure( self._num_of_nodes += 1 + if flag_mol_has_fg_using_smiles: + self._cnt_mol_with_fg_using_smiles += 1 + if flag_mol_has_fg_using_atom_symbol: + self._cnt_mol_with_fg_using_atom_symbol += 1 + return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds def _construct_fg_level_structure( @@ -514,3 +545,18 @@ def _construct_fg_to_graph_node_structure( self._num_of_nodes += 1 return graph_edge_index, graph_node, fg_graph_edges + + def on_finish(self): + super().on_finish() + + summary = textwrap.dedent( + f""" + ==== Functional Group Summary ================== + - Functional groups using SMILES: {self._cnt_fg_using_smiles} + - Molecules with such groups (SMILES): {self._cnt_mol_with_fg_using_smiles} + - Functional groups using atom symbols: {self._cnt_fg_using_atom_symbol} + - Molecules with such groups (atom symbols): {self._cnt_mol_with_fg_using_atom_symbol} + ================================================ + """ + ) + rank_zero_info(summary.strip()) From f22cd8480ea5e70b0eed746e5d2903b877b58c1e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 7 Jun 2025 17:10:05 +0200 Subject: [PATCH 104/224] assertion for atom belong to more than one fg --- .../preprocessing/reader/augmented_reader.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index fbcfa40..b02c7e2 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -155,11 +155,11 @@ def __init__(self, *args, 
**kwargs): super().__init__(*args, **kwargs) # Record number of functional groups using relevant part of SMILES which belong to them, as function group name self._cnt_fg_using_smiles = 0 - # Record number molecules with atleast one functional group using relevant part of SMILES which belong to them, as function group name + # Record number molecules with at least one functional group using relevant part of SMILES which belong to them, as function group name self._cnt_mol_with_fg_using_smiles = 0 # Record number of functional groups using atom symbol as functional group name self._cnt_fg_using_atom_symbol = 0 - # Record number molecules with atleast one functional group using atom symbol as functional group name + # Record number molecules with at least one functional group using atom symbol as functional group name self._cnt_mol_with_fg_using_atom_symbol = 0 @classmethod @@ -186,7 +186,13 @@ def _read_data(self, smiles: str) -> GeomData | None: if mol is None: return None - returned_result = self._create_augmented_graph(mol) + try: + returned_result = self._create_augmented_graph(mol) + except Exception as e: + raise RuntimeError( + f"Error has occurred for following SMILES: {smiles}\n\t {e}" + ) from e + # If the returned result is None, it indicates that the graph augmentation failed if returned_result is None: rank_zero_info(f"Failed to construct augmented graph for smiles {smiles}") @@ -379,6 +385,7 @@ def _construct_fg_to_atom_structure( flag_mol_has_fg_using_smiles = False flag_mol_has_fg_using_atom_symbol = False + molecule_atoms_set = set() for fg_smiles, fg_group in structure.items(): fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} @@ -386,6 +393,12 @@ def _construct_fg_to_atom_structure( connected_atoms = [] # Build edge index for fg to atom nodes connections for atom_idx in fg_group["atom"]: + if atom_idx in molecule_atoms_set: + raise ValueError( + f"An atom {atom_idx} cannot belong to more than one functional group" + ) + 
molecule_atoms_set.add(atom_idx) + fg_atom_edge_index[0].append(self._num_of_nodes) fg_atom_edge_index[1].append(atom_idx) atom_fg_edges[f"{self._num_of_nodes}_{atom_idx}"] = { From 08b7e7ddb735ebcf42332f904b97466dd8bf3c70 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 7 Jun 2025 22:55:49 +0200 Subject: [PATCH 105/224] no default values, need to explicitly provided --- chebai_graph/models/_gat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chebai_graph/models/_gat.py b/chebai_graph/models/_gat.py index da61235..5dbed73 100644 --- a/chebai_graph/models/_gat.py +++ b/chebai_graph/models/_gat.py @@ -17,9 +17,9 @@ def __init__(self, config: dict, **kwargs): self._dropout_rate = float(config.pop("dropout_rate", 0.1)) self._n_conv_layers = int(config.pop("n_conv_layers", 3)) self._n_linear_layers = int(config.pop("n_linear_layers", 3)) - self._n_atom_properties = int(config.pop("n_atom_properties", 0)) - self._n_bond_properties = int(config.pop("n_bond_properties", 0)) - self._n_molecule_properties = int(config.pop("n_molecule_properties", 0)) + self._n_atom_properties = int(config.pop("n_atom_properties")) + self._n_bond_properties = int(config.pop("n_bond_properties")) + self._n_molecule_properties = int(config.pop("n_molecule_properties")) self._gat = GAT( in_channels=self._n_atom_properties, hidden_channels=self._hidden_length, From 42e1b93eb161b4a736d7dae13de6ac0e268fa9d5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 7 Jun 2025 23:18:04 +0200 Subject: [PATCH 106/224] rectify assertion for multiple fg check --- chebai_graph/preprocessing/reader/augmented_reader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index b02c7e2..4ddb8ec 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -458,8 +458,11 @@ def 
_construct_fg_to_atom_structure( self._cnt_fg_using_smiles += 1 flag_mol_has_fg_using_smiles = True - if "" in fg_set and len(fg_set) > 1: - raise ValueError("Invalid functional group assignment to atoms.") + if len(fg_set - {""}) > 1: + raise ValueError( + "Connected atoms have different function groups assigned.\n" + "All Connected atoms must belong to one functional group or None" + ) # Select any one connected atom to get FG type and ring size representative_atom = next( From 7d6608d8999942923265b01990e52c24883285af Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 7 Jun 2025 23:56:33 +0200 Subject: [PATCH 107/224] common token for atoms with no FG - instead of atom symbol or fg smiles to limit the number of FG tokens --- .../preprocessing/reader/augmented_reader.py | 74 ++++++++++--------- 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 4ddb8ec..f8077bb 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -153,14 +153,10 @@ def __init__(self, *args, **kwargs): **kwargs: Additional keyword arguments passed to the parent class. 
""" super().__init__(*args, **kwargs) - # Record number of functional groups using relevant part of SMILES which belong to them, as function group name - self._cnt_fg_using_smiles = 0 - # Record number molecules with at least one functional group using relevant part of SMILES which belong to them, as function group name - self._cnt_mol_with_fg_using_smiles = 0 - # Record number of functional groups using atom symbol as functional group name - self._cnt_fg_using_atom_symbol = 0 - # Record number molecules with at least one functional group using atom symbol as functional group name - self._cnt_mol_with_fg_using_atom_symbol = 0 + # Record number of functional groups using default fg token as functional group name + self._num_mol_using_default_fg_token = 0 + # Record number molecules with at least one functional group using default fg token as function group name + self._num_of_fg_using_default_fg_token = 0 @classmethod def name(cls) -> str: @@ -382,8 +378,8 @@ def _construct_fg_to_atom_structure( fg_nodes, atom_fg_edges = {}, {} # Contains augmented fg-nodes and connected atoms indices fg_to_atoms_map = {} - flag_mol_has_fg_using_smiles = False - flag_mol_has_fg_using_atom_symbol = False + + flag_mol_has_used_default_fg = False molecule_atoms_set = set() for fg_smiles, fg_group in structure.items(): @@ -440,23 +436,35 @@ def _construct_fg_to_atom_structure( ) if "" in fg_set and len(fg_set) == 1: - if len(connected_atoms) == 1: - # If there is only one atom and one edge connecting this atom to its fg_atom, - # the functional group will be the symbol of this atom - # This special case is to handle wildcard SMILES Eg. 
CHEBI:33429 - atom = connected_atoms[0] - atom.SetProp("FG", atom.GetSymbol()) - self._cnt_fg_using_atom_symbol += 1 - flag_mol_has_fg_using_atom_symbol = True - else: - # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name - # assigned, we assign the functional group as the relevant part of SMILES which belong to the functional group - # Eg. CHEBI:55388, atom idx 2 and 3 have no functional group name, so "[C-]#[C-]" is used - for atom in connected_atoms: - atom.SetProp("FG", fg_smiles) - - self._cnt_fg_using_smiles += 1 - flag_mol_has_fg_using_smiles = True + # ------ Commented for Future Reference: Leads to large number of functional group tokens ------------------ + # ------ 2533 unique FG tokens which blow up the file size of corresponding property file up 183 GB + # ------ If uncommented in future, retrieve rest of the code related to this block from + # ------ https://github.com/ChEB-AI/python-chebai-graph/pull/2/commits/ac8cb1e3275296f3bbb9982302c2688af63c93d1 + # if len(connected_atoms) == 1: + # # If there is only one atom and one edge connecting this atom to its fg_atom, + # # the functional group will be the symbol of this atom + # # This special case is to handle wildcard SMILES Eg. CHEBI:33429 + # atom = connected_atoms[0] + # atom.SetProp("FG", atom.GetSymbol()) + # self._cnt_fg_using_atom_symbol += 1 + # flag_mol_has_fg_using_atom_symbol = True + # else: + # # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name + # # assigned, we assign the functional group as the relevant part of SMILES which belong to the functional group + # # Eg. 
CHEBI:55388, atom idx 2 and 3 have no functional group name, so "[C-]#[C-]" is used + # for atom in connected_atoms: + # atom.SetProp("FG", fg_smiles) + # + # self._cnt_fg_using_smiles += 1 + # flag_mol_has_fg_using_smiles = True + # ----------------------------------------------------------------------------- + + # Instead assign atoms with no FG to a common token to limit the number of unique FG tokens + for atom in connected_atoms: + atom.SetProp("FG", "No_FG_assigned") + + self._num_of_fg_using_default_fg_token += 1 + flag_mol_has_used_default_fg = True if len(fg_set - {""}) > 1: raise ValueError( @@ -481,10 +489,8 @@ def _construct_fg_to_atom_structure( self._num_of_nodes += 1 - if flag_mol_has_fg_using_smiles: - self._cnt_mol_with_fg_using_smiles += 1 - if flag_mol_has_fg_using_atom_symbol: - self._cnt_mol_with_fg_using_atom_symbol += 1 + if flag_mol_has_used_default_fg: + self._num_mol_using_default_fg_token += 1 return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds @@ -568,10 +574,8 @@ def on_finish(self): summary = textwrap.dedent( f""" ==== Functional Group Summary ================== - - Functional groups using SMILES: {self._cnt_fg_using_smiles} - - Molecules with such groups (SMILES): {self._cnt_mol_with_fg_using_smiles} - - Functional groups using atom symbols: {self._cnt_fg_using_atom_symbol} - - Molecules with such groups (atom symbols): {self._cnt_mol_with_fg_using_atom_symbol} + - Functional groups using default FG token: {self._num_of_fg_using_default_fg_token} + - Molecules with at least one FG using default token: {self._num_mol_using_default_fg_token} ================================================ """ ) From d372c924dd440ed8ede8a616ffb09ae91f6f5f27 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 19 Jun 2025 20:09:53 +0200 Subject: [PATCH 108/224] fix empty node error --- .../preprocessing/utils/visualize_augmented_molecule.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index aa0b62a..1b1aa7d 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -5,7 +5,7 @@ import networkx as nx from jsonargparse import CLI from PIL import Image -from rdkit.Chem import BondType, Mol, rdDepictor +from rdkit.Chem import AllChem, BondType, Mol, rdDepictor from rdkit.Chem.Draw import rdMolDraw2D from torch import Tensor @@ -77,7 +77,7 @@ def _create_graph( ) # Add special graph node - graph_node_idx = augmented_graph_nodes["num_nodes"] + graph_node_idx = augmented_graph_nodes["num_nodes"] - 1 G.add_node( graph_node_idx, node_name="Graph Node", @@ -364,6 +364,8 @@ def plot_nonaugment_molecule_graph(mol: Mol, size=(800, 800)) -> None: # Display atom indices and symbols options.addAtomIndices = True + # Show bond indices + options.addBondIndices = True options.addStereoAnnotation = True options.padding = 0.05 # Less whitespace options.fixedBondLength = 25 # for visual clarity From 64b808dd769f0bee65ee0c154a5df11986bd22f1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 19 Jun 2025 20:10:41 +0200 Subject: [PATCH 109/224] print statemetns instead rank in pre-processing stage --- .../preprocessing/reader/augmented_reader.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index f8077bb..acf0678 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -4,7 +4,6 @@ import torch from chebai.preprocessing.reader import DataReader -from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from rdkit import Chem from torch_geometric.data import Data as GeomData @@ -90,13 +89,13 @@ def 
_smiles_to_mol(self, smiles: str) -> Chem.Mol: """ mol = Chem.MolFromSmiles(smiles) if mol is None: - rank_zero_warn(f"RDKit failed to parse {smiles} (returned None)") + print(f"RDKit failed to parse {smiles} (returned None)") self.f_cnt_for_smiles += 1 else: try: Chem.SanitizeMol(mol) except Exception as e: - rank_zero_warn(f"RDKit failed at sanitizing {smiles}, Error {e}") + print(f"RDKit failed at sanitizing {smiles}, Error {e}") self.f_cnt_for_smiles += 1 return mol @@ -104,8 +103,8 @@ def on_finish(self) -> None: """ Finalizes the reading process and logs the number of failed SMILES and failed augmentation. """ - rank_zero_info(f"Failed to read {self.f_cnt_for_smiles} SMILES in total") - rank_zero_info( + print(f"Failed to read {self.f_cnt_for_smiles} SMILES in total") + print( f"Failed to construct augmented graph for {self.f_cnt_for_aug_graph} number of SMILES" ) self.mol_object_buffer = {} @@ -191,7 +190,7 @@ def _read_data(self, smiles: str) -> GeomData | None: # If the returned result is None, it indicates that the graph augmentation failed if returned_result is None: - rank_zero_info(f"Failed to construct augmented graph for smiles {smiles}") + print(f"Failed to construct augmented graph for smiles {smiles}") self.f_cnt_for_aug_graph += 1 return None @@ -579,4 +578,4 @@ def on_finish(self): ================================================ """ ) - rank_zero_info(summary.strip()) + print(summary.strip()) From 20338b543d7e982f7798fa0f6df40f12f088830f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 19 Jun 2025 20:19:38 +0200 Subject: [PATCH 110/224] Revert "common token for atoms with no FG" This reverts commit 7d6608d8999942923265b01990e52c24883285af. 
--- .../preprocessing/reader/augmented_reader.py | 74 +++++++++---------- 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index acf0678..b968ef6 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -152,10 +152,14 @@ def __init__(self, *args, **kwargs): **kwargs: Additional keyword arguments passed to the parent class. """ super().__init__(*args, **kwargs) - # Record number of functional groups using default fg token as functional group name - self._num_mol_using_default_fg_token = 0 - # Record number molecules with at least one functional group using default fg token as function group name - self._num_of_fg_using_default_fg_token = 0 + # Record number of functional groups using relevant part of SMILES which belong to them, as function group name + self._cnt_fg_using_smiles = 0 + # Record number molecules with at least one functional group using relevant part of SMILES which belong to them, as function group name + self._cnt_mol_with_fg_using_smiles = 0 + # Record number of functional groups using atom symbol as functional group name + self._cnt_fg_using_atom_symbol = 0 + # Record number molecules with at least one functional group using atom symbol as functional group name + self._cnt_mol_with_fg_using_atom_symbol = 0 @classmethod def name(cls) -> str: @@ -377,8 +381,8 @@ def _construct_fg_to_atom_structure( fg_nodes, atom_fg_edges = {}, {} # Contains augmented fg-nodes and connected atoms indices fg_to_atoms_map = {} - - flag_mol_has_used_default_fg = False + flag_mol_has_fg_using_smiles = False + flag_mol_has_fg_using_atom_symbol = False molecule_atoms_set = set() for fg_smiles, fg_group in structure.items(): @@ -435,35 +439,23 @@ def _construct_fg_to_atom_structure( ) if "" in fg_set and len(fg_set) == 1: - # ------ Commented for Future Reference: Leads to large 
number of functional group tokens ------------------ - # ------ 2533 unique FG tokens which blow up the file size of corresponding property file up 183 GB - # ------ If uncommented in future, retrieve rest of the code related to this block from - # ------ https://github.com/ChEB-AI/python-chebai-graph/pull/2/commits/ac8cb1e3275296f3bbb9982302c2688af63c93d1 - # if len(connected_atoms) == 1: - # # If there is only one atom and one edge connecting this atom to its fg_atom, - # # the functional group will be the symbol of this atom - # # This special case is to handle wildcard SMILES Eg. CHEBI:33429 - # atom = connected_atoms[0] - # atom.SetProp("FG", atom.GetSymbol()) - # self._cnt_fg_using_atom_symbol += 1 - # flag_mol_has_fg_using_atom_symbol = True - # else: - # # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name - # # assigned, we assign the functional group as the relevant part of SMILES which belong to the functional group - # # Eg. CHEBI:55388, atom idx 2 and 3 have no functional group name, so "[C-]#[C-]" is used - # for atom in connected_atoms: - # atom.SetProp("FG", fg_smiles) - # - # self._cnt_fg_using_smiles += 1 - # flag_mol_has_fg_using_smiles = True - # ----------------------------------------------------------------------------- - - # Instead assign atoms with no FG to a common token to limit the number of unique FG tokens - for atom in connected_atoms: - atom.SetProp("FG", "No_FG_assigned") - - self._num_of_fg_using_default_fg_token += 1 - flag_mol_has_used_default_fg = True + if len(connected_atoms) == 1: + # If there is only one atom and one edge connecting this atom to its fg_atom, + # the functional group will be the symbol of this atom + # This special case is to handle wildcard SMILES Eg. 
CHEBI:33429 + atom = connected_atoms[0] + atom.SetProp("FG", atom.GetSymbol()) + self._cnt_fg_using_atom_symbol += 1 + flag_mol_has_fg_using_atom_symbol = True + else: + # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name + # assigned, we assign the functional group as the relevant part of SMILES which belong to the functional group + # Eg. CHEBI:55388, atom idx 2 and 3 have no functional group name, so "[C-]#[C-]" is used + for atom in connected_atoms: + atom.SetProp("FG", fg_smiles) + + self._cnt_fg_using_smiles += 1 + flag_mol_has_fg_using_smiles = True if len(fg_set - {""}) > 1: raise ValueError( @@ -488,8 +480,10 @@ def _construct_fg_to_atom_structure( self._num_of_nodes += 1 - if flag_mol_has_used_default_fg: - self._num_mol_using_default_fg_token += 1 + if flag_mol_has_fg_using_smiles: + self._cnt_mol_with_fg_using_smiles += 1 + if flag_mol_has_fg_using_atom_symbol: + self._cnt_mol_with_fg_using_atom_symbol += 1 return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds @@ -573,8 +567,10 @@ def on_finish(self): summary = textwrap.dedent( f""" ==== Functional Group Summary ================== - - Functional groups using default FG token: {self._num_of_fg_using_default_fg_token} - - Molecules with at least one FG using default token: {self._num_mol_using_default_fg_token} + - Functional groups using SMILES: {self._cnt_fg_using_smiles} + - Molecules with such groups (SMILES): {self._cnt_mol_with_fg_using_smiles} + - Functional groups using atom symbols: {self._cnt_fg_using_atom_symbol} + - Molecules with such groups (atom symbols): {self._cnt_mol_with_fg_using_atom_symbol} ================================================ """ ) From f4829ef2229dd4ca28f24720cfba56629f10df34 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 20 Jun 2025 18:57:40 +0200 Subject: [PATCH 111/224] separate out logic of setting fg prop into diff funcs --- .../preprocessing/reader/augmented_reader.py | 167 
++++++++---------- 1 file changed, 70 insertions(+), 97 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index b968ef6..52e34a0 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -143,24 +143,6 @@ class GraphFGAugmentorReader(_AugmentorReader): The FG nodes to connected to its related atoms and graph node is connected to all FG nodes. """ - def __init__(self, *args, **kwargs): - """ - Initializes the GraphFGAugmentorReader and sets up the failure counter and molecule cache. - - Args: - *args: Additional arguments passed to the parent class. - **kwargs: Additional keyword arguments passed to the parent class. - """ - super().__init__(*args, **kwargs) - # Record number of functional groups using relevant part of SMILES which belong to them, as function group name - self._cnt_fg_using_smiles = 0 - # Record number molecules with at least one functional group using relevant part of SMILES which belong to them, as function group name - self._cnt_mol_with_fg_using_smiles = 0 - # Record number of functional groups using atom symbol as functional group name - self._cnt_fg_using_atom_symbol = 0 - # Record number molecules with at least one functional group using atom symbol as functional group name - self._cnt_mol_with_fg_using_atom_symbol = 0 - @classmethod def name(cls) -> str: """ @@ -381,14 +363,11 @@ def _construct_fg_to_atom_structure( fg_nodes, atom_fg_edges = {}, {} # Contains augmented fg-nodes and connected atoms indices fg_to_atoms_map = {} - flag_mol_has_fg_using_smiles = False - flag_mol_has_fg_using_atom_symbol = False molecule_atoms_set = set() - for fg_smiles, fg_group in structure.items(): + for _, fg_group in structure.items(): fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} - ring_fg = set() connected_atoms = [] # Build edge index for fg to atom nodes connections for atom_idx 
in fg_group["atom"]: @@ -405,88 +384,82 @@ def _construct_fg_to_atom_structure( } self._num_of_edges += 1 - atom = mol.GetAtomWithIdx( - atom_idx - ) # reference to atom in mol is returned + atom = mol.GetAtomWithIdx(atom_idx) connected_atoms.append(atom) - if atom.GetProp("RING"): - ring_fg.add(atom.GetProp("RING")) - - if len(ring_fg) > 1: - raise ValueError( - "A functional group must not span multiple ring sizes." - ) - - if len(ring_fg) == 1: - # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings - ring_size = next(iter(ring_fg)) - fg_nodes[self._num_of_nodes] = { - NODE_LEVEL: FG_NODE_LEVEL, - # E.g., Fused Ring has size "5-6", indicating size of each connected ring in fused ring - "FG": f"RING_{ring_size}", - "RING": ring_size, - } - # In this case, all atoms of Ring/Fused Ring are assigned the ring size as functional group - for atom in connected_atoms: - atom.SetProp("FG", f"RING_{ring_size}") - - else: # No connected atoms have a ring size which indicates it is simple FG - fg_set = {atom.GetProp("FG") for atom in connected_atoms} - if not fg_set: - raise ValueError( - "No functional group assigned to atoms in the functional group." - ) - - if "" in fg_set and len(fg_set) == 1: - if len(connected_atoms) == 1: - # If there is only one atom and one edge connecting this atom to its fg_atom, - # the functional group will be the symbol of this atom - # This special case is to handle wildcard SMILES Eg. CHEBI:33429 - atom = connected_atoms[0] - atom.SetProp("FG", atom.GetSymbol()) - self._cnt_fg_using_atom_symbol += 1 - flag_mol_has_fg_using_atom_symbol = True - else: - # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name - # assigned, we assign the functional group as the relevant part of SMILES which belong to the functional group - # Eg. 
CHEBI:55388, atom idx 2 and 3 have no functional group name, so "[C-]#[C-]" is used - for atom in connected_atoms: - atom.SetProp("FG", fg_smiles) - - self._cnt_fg_using_smiles += 1 - flag_mol_has_fg_using_smiles = True - - if len(fg_set - {""}) > 1: - raise ValueError( - "Connected atoms have different function groups assigned.\n" - "All Connected atoms must belong to one functional group or None" - ) - - # Select any one connected atom to get FG type and ring size - representative_atom = next( - (atom for atom in connected_atoms if atom.GetProp("FG")), None - ) - if representative_atom is None: - raise AssertionError( - "Expected at least one atom with a functional group." - ) - - fg_nodes[self._num_of_nodes] = { - NODE_LEVEL: FG_NODE_LEVEL, - "FG": representative_atom.GetProp("FG"), - "RING": representative_atom.GetProp("RING"), - } + if fg_group["is_ring_fg"]: + self._set_ring_fg_prop(connected_atoms, fg_nodes) + else: + self._set_fg_prop(connected_atoms, fg_nodes) self._num_of_nodes += 1 - if flag_mol_has_fg_using_smiles: - self._cnt_mol_with_fg_using_smiles += 1 - if flag_mol_has_fg_using_atom_symbol: - self._cnt_mol_with_fg_using_atom_symbol += 1 - return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds + def _set_ring_fg_prop(self, connected_atoms, fg_nodes): + ring_fg = { + atom.GetProp("RING") for atom in connected_atoms if atom.GetProp("RING") + } + if len(ring_fg) > 1: + raise ValueError("A functional group must not span multiple ring sizes.") + if len(ring_fg) == 0: + raise ValueError( + "At least one connected atoms must have a `RING` property set." 
+ ) + + # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings + ring_size = next(iter(ring_fg)) + fg_nodes[self._num_of_nodes] = { + NODE_LEVEL: FG_NODE_LEVEL, + # E.g., Fused Ring has size "5-6", indicating size of each connected ring in fused ring + "FG": f"RING_{ring_size}", + "RING": ring_size, + } + # In this case, all atoms of Ring/Fused Ring are assigned the ring size as functional group + for atom in connected_atoms: + atom.SetProp("FG", f"RING_{ring_size}") + + def _set_fg_prop(self, connected_atoms, fg_nodes): + fg_set = {atom.GetProp("FG") for atom in connected_atoms} + if not fg_set: + raise ValueError( + "No functional group assigned to atoms in the functional group." + ) + + if "" in fg_set and len(fg_set) == 1: + if len(connected_atoms) == 1: + # If there is only one atom and one edge connecting this atom to its fg_atom, + # the functional group will be the symbol of this atom + # This special case is to handle wildcard SMILES Eg. CHEBI:33429 + atom = connected_atoms[0] + # TODO: needed or can we set to default fg prop `NO_FG`? + atom.SetProp("FG", atom.GetSymbol()) + else: + # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name + # assigned, Eg. 
CHEBI:55388, atom idx 2 and 3 ([C-]#[C-]") have no functional group name, so default FG prop is used + for atom in connected_atoms: + atom.SetProp("FG", "NO_FG") + # atom.SetProp("FG", fg_smiles) + + if len(fg_set - {""}) > 1: + raise ValueError( + "Connected atoms have different function groups assigned.\n" + "All Connected atoms must belong to one functional group or None" + ) + + # Select any one connected atom to get FG type and ring size + representative_atom = next( + (atom for atom in connected_atoms if atom.GetProp("FG")), None + ) + if representative_atom is None: + raise AssertionError("Expected at least one atom with a functional group.") + + fg_nodes[self._num_of_nodes] = { + NODE_LEVEL: FG_NODE_LEVEL, + "FG": representative_atom.GetProp("FG"), + "RING": 0, + } + def _construct_fg_level_structure( self, fg_to_atoms_map: dict, bonds: list ) -> Tuple[List[List[int]], dict]: From 5e9a133aae2f1238ee1a2f9b1da14394ed8c8a01 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 20 Jun 2025 20:56:26 +0200 Subject: [PATCH 112/224] separate frag for each ring in fused ring --- .../fg_detection/fg_aware_rule_based.py | 55 +++++++++++++++++-- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index 186ff70..2908e60 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -34,7 +34,7 @@ def find_connected_rings(ring, remaining_rings): def detect_functional_group(mol): # type: ignore - AllChem.GetSymmSSSR(mol) # type: ignore + AllChem.GetSymmSSSR(mol) if mol is not None: for atom in mol.GetAtoms(): @@ -1869,6 +1869,7 @@ def find_atom_map(smiles): def get_structure(mol): rings = mol.GetRingInfo().AtomRings() + fused_rings_groups: list[list[set[int]]] = _get_fused_rings_group(mol) splitting_bonds = set() for bond in mol.GetBonds(): @@ -1916,14 
+1917,58 @@ def get_structure(mol): structure = {} for frag in smiles: - atom_idx, neighbor_idx = set(), set() - atom_idx = find_atom_map(frag) - neighbor_idx = find_neighbor_map(frag) - structure[frag] = {"atom": atom_idx, "neighbor": neighbor_idx} + atom_idx: set = find_atom_map(frag) + structure[frag] = {"atom": atom_idx, "is_ring_fg": False} + + # Convert fragment SMILES back to mol to match with fused ring atom indices + frag_mol = Chem.MolFromSmiles(frag) + frag_rings = frag_mol.GetRingInfo().AtomRings() + if len(frag_rings) >= 1: + structure[frag]["is_ring_fg"] = True + + if len(frag_rings) <= 1: + continue + + if not fused_rings_groups: + continue + + for group in fused_rings_groups: + flat_atoms = set().union(*group) + if flat_atoms.issubset(atom_idx) and len(flat_atoms) == len(atom_idx): + for idx, ring_atoms in enumerate(group): + structure[f"{frag}_{idx+1}"] = { + "atom": ring_atoms, + "is_ring_fg": True, + } + structure.pop(frag) + break return structure, BONDS +def _get_fused_rings_group(mol: Chem.Mol) -> list[list[set[int]]]: + rings = mol.GetRingInfo().AtomRings() + + fused_ring_groups = [] + visited = set() + + for i, ring1 in enumerate(rings): + if i in visited: + continue + fused_group = [set(ring1)] + visited.add(i) + for j, ring2 in enumerate(rings): + if j in visited or i == j: + continue + if len(set(ring1) & set(ring2)) >= 2: # At least 2 shared atoms + fused_group.append(set(ring2)) + visited.add(j) + if len(fused_group) > 1: + fused_ring_groups.append(fused_group) + + return fused_ring_groups + + if __name__ == "__main__": from rdkit.Chem import MolFromSmiles as s2m From 698a607ff64ce5d7f47025a16a53de0144383ade Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 20 Jun 2025 20:59:23 +0200 Subject: [PATCH 113/224] max_ring_size of atom belonging to more than one ring in fused ring --- .../preprocessing/reader/augmented_reader.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git 
a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 52e34a0..89d04a8 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -367,11 +367,13 @@ def _construct_fg_to_atom_structure( molecule_atoms_set = set() for _, fg_group in structure.items(): fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} + is_ring_fg = fg_group["is_ring_fg"] connected_atoms = [] # Build edge index for fg to atom nodes connections for atom_idx in fg_group["atom"]: - if atom_idx in molecule_atoms_set: + # Fused rings can have an atom which belong to more than one ring + if atom_idx in molecule_atoms_set and not is_ring_fg: raise ValueError( f"An atom {atom_idx} cannot belong to more than one functional group" ) @@ -387,7 +389,7 @@ def _construct_fg_to_atom_structure( atom = mol.GetAtomWithIdx(atom_idx) connected_atoms.append(atom) - if fg_group["is_ring_fg"]: + if is_ring_fg: self._set_ring_fg_prop(connected_atoms, fg_nodes) else: self._set_fg_prop(connected_atoms, fg_nodes) @@ -397,18 +399,8 @@ def _construct_fg_to_atom_structure( return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds def _set_ring_fg_prop(self, connected_atoms, fg_nodes): - ring_fg = { - atom.GetProp("RING") for atom in connected_atoms if atom.GetProp("RING") - } - if len(ring_fg) > 1: - raise ValueError("A functional group must not span multiple ring sizes.") - if len(ring_fg) == 0: - raise ValueError( - "At least one connected atoms must have a `RING` property set." 
- ) - # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings - ring_size = next(iter(ring_fg)) + ring_size = len(connected_atoms) fg_nodes[self._num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, # E.g., Fused Ring has size "5-6", indicating size of each connected ring in fused ring @@ -417,7 +409,11 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes): } # In this case, all atoms of Ring/Fused Ring are assigned the ring size as functional group for atom in connected_atoms: - atom.SetProp("FG", f"RING_{ring_size}") + ring_prop = atom.GetProp("RING") + if not ring_prop: + raise ValueError("Atom does not have a ring size set") + max_ring_size = max(list(map(int, ring_prop.split("-")))) + atom.SetProp("FG", f"RING_{max_ring_size}") def _set_fg_prop(self, connected_atoms, fg_nodes): fg_set = {atom.GetProp("FG") for atom in connected_atoms} From 4b4179b60358d4146f1a5a281d2b96d4af55e62a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 20 Jun 2025 23:46:39 +0200 Subject: [PATCH 114/224] edges between fg of rings belonging to fused rings --- .../preprocessing/reader/augmented_reader.py | 59 ++++++++++--------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 89d04a8..42a86b1 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -366,7 +366,7 @@ def _construct_fg_to_atom_structure( molecule_atoms_set = set() for _, fg_group in structure.items(): - fg_to_atoms_map[self._num_of_nodes] = {"atom": fg_group["atom"]} + fg_to_atoms_map[self._num_of_nodes] = fg_group is_ring_fg = fg_group["is_ring_fg"] connected_atoms = [] @@ -474,21 +474,14 @@ def _construct_fg_level_structure( internal_fg_edges = {} internal_edge_index = [[], []] - for bond in bonds: - source_atom, target_atom = bond[:2] - source_fg, target_fg = None, None - - for fg_id, data in 
fg_to_atoms_map.items(): - if source_atom in data["atom"]: - source_fg = fg_id - if target_atom in data["atom"]: - target_fg = fg_id - + def add_fg_internal_edge(source_fg, target_fg): assert ( source_fg is not None and target_fg is not None ), "Each bond should have a fg node on both end" + assert source_fg != target_fg, "Source and Target FG should be different" - edge_str = f"{source_fg}_{target_fg}" + edge_key = tuple(sorted((source_fg, target_fg))) + edge_str = f"{edge_key[0]}_{edge_key[1]}" if edge_str not in internal_fg_edges: # If two atoms of a FG points to atom(s) belonging to another FG. In this case, only one edge is counted. # Eg. In CHEBI:52723, atom idx 13 and 16 of a FG points to atom idx 18 of another FG @@ -497,6 +490,33 @@ def _construct_fg_level_structure( internal_fg_edges[edge_str] = {EDGE_LEVEL: WITHIN_FG_EDGE} self._num_of_edges += 1 + for bond in bonds: + source_atom, target_atom = bond[:2] + source_fg, target_fg = None, None + for fg_id, data in fg_to_atoms_map.items(): + if source_fg is None and source_atom in data["atom"]: + source_fg = fg_id + if target_fg is None and target_atom in data["atom"]: + target_fg = fg_id + if source_fg is not None and target_fg is not None: + break + add_fg_internal_edge(source_fg, target_fg) + + # For Rings belonging to fused rings + fg_nodes = list(fg_to_atoms_map.keys()) + for i, fg_node_1 in enumerate(fg_nodes): + fg_map_1 = fg_to_atoms_map[fg_node_1] + for fg_node_2 in fg_nodes[i + 1 :]: + fg_map_2 = fg_to_atoms_map[fg_node_2] + if ( + (fg_node_1 == fg_node_2) + or not fg_map_1["is_ring_fg"] + or not fg_map_2["is_ring_fg"] + ): + continue + if fg_map_1["atom"] & fg_map_2["atom"]: + add_fg_internal_edge(fg_node_1, fg_node_2) + return internal_edge_index, internal_fg_edges def _construct_fg_to_graph_node_structure( @@ -529,18 +549,3 @@ def _construct_fg_to_graph_node_structure( self._num_of_nodes += 1 return graph_edge_index, graph_node, fg_graph_edges - - def on_finish(self): - super().on_finish() - - 
summary = textwrap.dedent( - f""" - ==== Functional Group Summary ================== - - Functional groups using SMILES: {self._cnt_fg_using_smiles} - - Molecules with such groups (SMILES): {self._cnt_mol_with_fg_using_smiles} - - Functional groups using atom symbols: {self._cnt_fg_using_atom_symbol} - - Molecules with such groups (atom symbols): {self._cnt_mol_with_fg_using_atom_symbol} - ================================================ - """ - ) - print(summary.strip()) From 0f137b85f37131fd23e02a5ac73d492710347d42 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Jun 2025 00:25:39 +0200 Subject: [PATCH 115/224] =?UTF-8?q?3d=20=E2=97=80=EF=B8=8F=20=20mean=20pos?= =?UTF-8?q?ition=20of=20connected=20atoms=20for=20aug=20nodes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../utils/visualize_augmented_molecule.py | 35 +++++++++++++++---- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 1b1aa7d..307585b 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -3,6 +3,7 @@ import matplotlib import matplotlib.pyplot as plt import networkx as nx +import numpy as np from jsonargparse import CLI from PIL import Image from rdkit.Chem import AllChem, BondType, Mol, rdDepictor @@ -232,14 +233,34 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: for pos in [conf.GetAtomPosition(atom.GetIdx())] } - # Generate 3D layout for FG and graph nodes - fg_graph = _get_subgraph_by_node_type(G, "fg") - fg_pos_3d = nx.spring_layout(fg_graph, seed=42, dim=3) - fg_pos = {k: (x, y, z + 2) for k, (x, y, z) in fg_pos_3d.items()} + # Dictionary to store functional group node positions + fg_pos = {} - graph_node_graph = _get_subgraph_by_node_type(G, "graph") - graph_pos_3d = 
nx.spring_layout(graph_node_graph, seed=123, dim=3) - graph_pos = {k: (x, y, z + 4) for k, (x, y, z) in graph_pos_3d.items()} + # Loop through each functional group node in the graph + for fg_node in _get_subgraph_by_node_type(G, "fg").nodes(): + # Get connected atom nodes (assuming edges are between fg and atom nodes) + connected_atoms = [ + nbr + for nbr in G.neighbors(fg_node) + if G.nodes[nbr].get("node_type") == "atom" + ] + + # Get the 2D positions of the connected atoms + positions = np.array([atom_pos[atom] for atom in connected_atoms]) + x_mean, y_mean = positions[:, 0].mean(), positions[:, 1].mean() + fg_pos[fg_node] = (x_mean, y_mean, 2) # z = 2 for elevation + + graph_node = next(iter(_get_subgraph_by_node_type(G, "graph").nodes())) + graph_pos_arr = np.array( + [ + fg_pos[nbr] + for nbr in G.neighbors(graph_node) + if G.nodes[nbr].get("node_type") == "fg" + ] + ) + graph_pos = { + graph_node: (graph_pos_arr[:, 0].mean(), graph_pos_arr[:, 1].mean(), 4) + } pos = {**atom_pos, **fg_pos, **graph_pos} From adb37441c065e98a84cba739d5a24c2528173c97 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Jun 2025 12:12:38 +0200 Subject: [PATCH 116/224] fix huge fused ring issue --- .../fg_detection/fg_aware_rule_based.py | 3246 ++++++++--------- .../preprocessing/reader/augmented_reader.py | 9 +- 2 files changed, 1598 insertions(+), 1657 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index 2908e60..e628a52 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -19,7 +19,7 @@ def ring_size_processing(ring_size): # Function to find all rings connected to a given ring -def find_connected_rings(ring, remaining_rings): +def find_connected_rings(ring, remaining_rings) -> list[set[int]]: connected_rings = [ring] merged = True while merged: @@ -33,344 +33,326 @@ def 
find_connected_rings(ring, remaining_rings): return connected_rings -def detect_functional_group(mol): # type: ignore - AllChem.GetSymmSSSR(mol) +def set_ring_properties(mol: Chem.Mol) -> list[list[set[int]]] | None: - if mol is not None: - for atom in mol.GetAtoms(): - atom.SetProp("FG", "") - atom.SetProp("RING", "") + if mol is None: + return - ######## SET RING PROP ######## - # Get ring information - ring_info = mol.GetRingInfo() + AllChem.GetSymmSSSR(mol) - if ring_info.NumRings() > 0: - # Get list of atom rings - atom_rings = ring_info.AtomRings() + ######## SET RING PROP ######## + # Get ring information + ring_info = mol.GetRingInfo() - # Initialize a list to hold fused ring blocks and their sizes - fused_ring_blocks = [] - ring_sizes = [] + if ring_info.NumRings() > 0: + # Get list of atom rings + atom_rings = ring_info.AtomRings() - # Set of rings to process - remaining_rings = [set(ring) for ring in atom_rings] + # Initialize a list to hold fused ring blocks and their sizes + fused_ring_blocks = [] + ring_sizes = [] - # Process each ring block - while remaining_rings: - ring = remaining_rings.pop(0) - connected_rings = find_connected_rings(ring, remaining_rings) + # Set of rings to process + remaining_rings = [set(ring) for ring in atom_rings] + fused_rings_groups: list[list[set[int]]] = [] - # Merge all connected rings into one fused block - fused_block = set().union(*connected_rings) - fused_ring_blocks.append(sorted(fused_block)) - ring_sizes.append([len(r) for r in connected_rings]) + # Process each ring block + while remaining_rings: + ring = remaining_rings.pop(0) + connected_rings: list[set[int]] = find_connected_rings( + ring, remaining_rings + ) - # Display the fused ring blocks and their ring sizes - for i, block in enumerate(fused_ring_blocks): - rs = "-".join(str(size) for size in ring_size_processing(ring_sizes[i])) - for idx in block: - atom = mol.GetAtomWithIdx(idx) - atom.SetProp("RING", rs) + if len(connected_rings) > 1: + 
fused_rings_groups.append(connected_rings) + + for ring in connected_rings: + ring_size = len(ring) + for atom_idx in ring: + atom = mol.GetAtomWithIdx(atom_idx) + if not atom.HasProp("RING"): + atom.SetProp("RING", f"{ring_size}") + else: + # An atom shared across multiple rings in fused-ring will have Ring size like "5-6-6" + atom.SetProp("RING", f"{atom.GetProp('RING')}-{ring_size}") + + # Merge all connected rings into one fused block + fused_block = set().union(*connected_rings) + fused_ring_blocks.append(sorted(fused_block)) + ring_sizes.append([len(r) for r in connected_rings]) + + # Display the fused ring blocks and their ring sizes + for i, block in enumerate(fused_ring_blocks): + rs = "-".join(str(size) for size in ring_size_processing(ring_sizes[i])) + for idx in block: + atom = mol.GetAtomWithIdx(idx) + atom.SetProp("RING", rs) + + return fused_rings_groups + + +def detect_functional_group(mol: Chem.Mol): + if mol is None: + return + + ######## SET FUNCTIONAL GROUP PROP ######## + for atom in mol.GetAtoms(): + atom.SetProp("FG", "") + atom_symbol = atom.GetSymbol() + atom_neighbors = atom.GetNeighbors() + atom_num_neighbors = len(atom_neighbors) + num_H = atom.GetTotalNumHs() + in_ring = atom.IsInRing() + atom_idx = atom.GetIdx() + charge = atom.GetFormalCharge() + + ########################### Groups containing oxygen ########################### + if atom_symbol in ["C", "*"] and charge == 0: # and atom.GetProp('FG') == '': + num_O, num_X, num_C, num_N, num_S = 0, 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Cl", "Br", "I"]: + num_X += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "N": + num_N += 1 + if neighbor.GetSymbol() == "S": + num_S += 1 + + if ( + num_H == 1 + and atom_num_neighbors == 3 + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "tertiary_carbon") + if atom_num_neighbors == 4 and charge == 0 
and atom.GetProp("FG") == "": + atom.SetProp("FG", "quaternary_carbon") + if ( + num_H == 0 + and atom_num_neighbors == 3 + and charge == 0 + and atom.GetProp("FG") == "" + and not in_ring + ): + atom.SetProp("FG", "alkene_carbon") - ######## SET FUNCTIONAL GROUP PROP ######## - for atom in mol.GetAtoms(): - atom_symbol = atom.GetSymbol() - atom_neighbors = atom.GetNeighbors() - atom_num_neighbors = len(atom_neighbors) - num_H = atom.GetTotalNumHs() - in_ring = atom.IsInRing() - atom_idx = atom.GetIdx() - charge = atom.GetFormalCharge() - - ########################### Groups containing oxygen ########################### if ( - atom_symbol in ["C", "*"] and charge == 0 - ): # and atom.GetProp('FG') == '': - num_O, num_X, num_C, num_N, num_S = 0, 0, 0, 0, 0 + num_O == 1 + and atom_symbol == "C" + and atom.GetProp("FG") + not in [ + "hemiacetal", + "hemiketal", + "acetal", + "ketal", + "orthoester", + "orthocarbonate_ester", + "carbonate_ester", + ] + ): + if num_N == 1: # Cyanate and Isocyanate + condition1, condition2 = False, False + condition3, condition4 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition3 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition4 = True + + if condition1 and condition2 and not in_ring: # Cyanate + atom.SetProp("FG", "cyanate") + for neighbor in atom_neighbors: + 
neighbor.SetProp("FG", "cyanate") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + for C_neighbor in neighbor.GetNeighbors(): + if ( + C_neighbor.GetSymbol() in ["C", "*"] + and C_neighbor.GetIdx() != atom_idx + ): + C_neighbor.SetProp("FG", "") + + if condition3 and condition4 and not in_ring: # Isocyanate + atom.SetProp("FG", "isocyanate") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "isocyanate") + for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["F", "Cl", "Br", "I"]: - num_X += 1 if neighbor.GetSymbol() == "O": - num_O += 1 - if neighbor.GetSymbol() in ["C", "*"]: - num_C += 1 - if neighbor.GetSymbol() == "N": - num_N += 1 - if neighbor.GetSymbol() == "S": - num_S += 1 + bond = mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()) + bondtype = bond.GetBondType() + if ( + bondtype == Chem.BondType.SINGLE + ): # and not neighbor.IsInRing(): # [C-O]: Alcohol (COH) or Ether [COC] or Hydroperoxy [C-O-O-H] or Peroxide [C-O-O-C] + if neighbor.GetTotalNumHs() == 1: # Alcohol [COH] + neighbor.SetProp("FG", "hydroxyl") + else: + for O_neighbor in neighbor.GetNeighbors(): + # if not O_neighbor.IsInRing(): + if ( + O_neighbor.GetIdx() != atom_idx + and O_neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetProp("FG") == "" + ): # Ether [COC] + neighbor.SetProp("FG", "ether") + if O_neighbor.GetSymbol() == "O": + if ( + O_neighbor.GetTotalNumHs() == 1 + ): # Hydroperoxy [C-O-O-H] + neighbor.SetProp("FG", "hydroperoxy") + O_neighbor.SetProp("FG", "hydroperoxy") + else: + neighbor.SetProp("FG", "peroxy") + O_neighbor.SetProp("FG", "peroxy") - if ( - num_H == 1 - and atom_num_neighbors == 3 - and charge == 0 - and atom.GetProp("FG") == "" - ): - atom.SetProp("FG", "tertiary_carbon") - if atom_num_neighbors == 4 and charge == 0 and atom.GetProp("FG") == "": - atom.SetProp("FG", "quaternary_carbon") - if ( - num_H == 0 - and atom_num_neighbors == 3 - and charge == 0 - and atom.GetProp("FG") == "" - and not in_ring - ): - 
atom.SetProp("FG", "alkene_carbon") + if ( + bondtype == Chem.BondType.DOUBLE + ): # [C=O]: Ketone [CC(=0)C] or Aldehyde [CC(=O)H] or Acyl halide [C(=O)X] + if ( + num_X == 1 and not neighbor.IsInRing() + ): # Acyl halide [C(=O)X] + atom.SetProp("FG", "haloformyl") + for neighbor_ in atom_neighbors: + if neighbor_.GetSymbol() in [ + "O", + "F", + "Cl", + "Br", + "I", + ]: + neighbor_.SetProp("FG", "haloformyl") - if ( - num_O == 1 - and atom_symbol == "C" - and atom.GetProp("FG") - not in [ - "hemiacetal", - "hemiketal", - "acetal", - "ketal", - "orthoester", - "orthocarbonate_ester", - "carbonate_ester", - ] - ): - if num_N == 1: # Cyanate and Isocyanate - condition1, condition2 = False, False - condition3, condition4 = False, False + if ( + (num_C == 1 and num_H == 1) + or num_H == 2 + and not in_ring + ): # Aldehyde [C(=O)H] + atom.SetProp("FG", "aldehyde") + neighbor.SetProp("FG", "aldehyde") + + if atom_num_neighbors == 3 and atom.GetProp("FG") not in [ + "haloformyl", + "amide", + ]: # Ketone [C(=0)C] + atom.SetProp("FG", "ketone") + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and not neighbor.IsInRing() + ): + neighbor.SetProp("FG", "ketone") + + if num_O == 2: # and atom.GetProp('FG') == '': + if atom_num_neighbors == 3: + if num_H == 0: + condition1, condition2, condition3, condition4 = ( + False, + False, + False, + False, + ) for neighbor in atom_neighbors: if ( - neighbor.GetSymbol() == "N" + neighbor.GetSymbol() == "O" and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() + atom.GetIdx(), neighbor.GetIdx() ).GetBondType() - == Chem.BondType.TRIPLE + == Chem.BondType.DOUBLE and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() ): condition1 = True if ( neighbor.GetSymbol() == "O" and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() + atom.GetIdx(), neighbor.GetIdx() ).GetBondType() == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + and not neighbor.IsInRing() ): condition2 = True - if ( - 
neighbor.GetSymbol() == "N" + neighbor.GetSymbol() == "O" and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() + atom.GetIdx(), neighbor.GetIdx() ).GetBondType() - == Chem.BondType.DOUBLE + == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() ): condition3 = True if ( neighbor.GetSymbol() == "O" and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() + atom.GetIdx(), neighbor.GetIdx() ).GetBondType() - == Chem.BondType.DOUBLE + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + and atom.GetProp("FG") != "carbamate" ): condition4 = True - if condition1 and condition2 and not in_ring: # Cyanate - atom.SetProp("FG", "cyanate") + if condition1 and condition2: + atom.SetProp("FG", "carboxylate") for neighbor in atom_neighbors: - neighbor.SetProp("FG", "cyanate") + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "carboxylate") + if condition1 and condition3: + atom.SetProp("FG", "carboxyl") for neighbor in atom_neighbors: if neighbor.GetSymbol() == "O": - for C_neighbor in neighbor.GetNeighbors(): - if ( - C_neighbor.GetSymbol() in ["C", "*"] - and C_neighbor.GetIdx() != atom_idx - ): - C_neighbor.SetProp("FG", "") - - if condition3 and condition4 and not in_ring: # Isocyanate - atom.SetProp("FG", "isocyanate") + neighbor.SetProp("FG", "carboxyl") + if ( + condition1 + and condition4 + and atom.GetProp("FG") + not in ["carbamate", "carbonate_ester"] + ): + atom.SetProp("FG", "ester") for neighbor in atom_neighbors: - neighbor.SetProp("FG", "isocyanate") - - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - bond = mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()) - bondtype = bond.GetBondType() - if ( - bondtype == Chem.BondType.SINGLE - ): # and not neighbor.IsInRing(): # [C-O]: Alcohol (COH) or Ether [COC] or Hydroperoxy [C-O-O-H] or Peroxide [C-O-O-C] - if neighbor.GetTotalNumHs() == 1: # Alcohol [COH] - 
neighbor.SetProp("FG", "hydroxyl") - else: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "ester") for O_neighbor in neighbor.GetNeighbors(): - # if not O_neighbor.IsInRing(): - if ( - O_neighbor.GetIdx() != atom_idx - and O_neighbor.GetSymbol() in ["C", "*"] - and neighbor.GetProp("FG") == "" - ): # Ether [COC] - neighbor.SetProp("FG", "ether") - if O_neighbor.GetSymbol() == "O": - if ( - O_neighbor.GetTotalNumHs() == 1 - ): # Hydroperoxy [C-O-O-H] - neighbor.SetProp("FG", "hydroperoxy") - O_neighbor.SetProp("FG", "hydroperoxy") - else: - neighbor.SetProp("FG", "peroxy") - O_neighbor.SetProp("FG", "peroxy") - - if ( - bondtype == Chem.BondType.DOUBLE - ): # [C=O]: Ketone [CC(=0)C] or Aldehyde [CC(=O)H] or Acyl halide [C(=O)X] - if ( - num_X == 1 and not neighbor.IsInRing() - ): # Acyl halide [C(=O)X] - atom.SetProp("FG", "haloformyl") - for neighbor_ in atom_neighbors: - if neighbor_.GetSymbol() in [ - "O", - "F", - "Cl", - "Br", - "I", - ]: - neighbor_.SetProp("FG", "haloformyl") - - if ( - (num_C == 1 and num_H == 1) - or num_H == 2 - and not in_ring - ): # Aldehyde [C(=O)H] - atom.SetProp("FG", "aldehyde") - neighbor.SetProp("FG", "aldehyde") - - if atom_num_neighbors == 3 and atom.GetProp( - "FG" - ) not in [ - "haloformyl", - "amide", - ]: # Ketone [C(=0)C] - atom.SetProp("FG", "ketone") - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and not neighbor.IsInRing() - ): - neighbor.SetProp("FG", "ketone") - - if num_O == 2: # and atom.GetProp('FG') == '': - if atom_num_neighbors == 3: - if num_H == 0: - condition1, condition2, condition3, condition4 = ( - False, - False, - False, - False, - ) - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - and not neighbor.IsInRing() - ): - condition1 = True - if ( - neighbor.GetSymbol() == "O" - and 
mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == -1 - and not neighbor.IsInRing() - ): - condition2 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - and neighbor.GetTotalNumHs() == 1 - and not neighbor.IsInRing() - ): - condition3 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - and neighbor.GetTotalNumHs() == 0 - and atom.GetProp("FG") != "carbamate" - ): - condition4 = True - - if condition1 and condition2: - atom.SetProp("FG", "carboxylate") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "carboxylate") - if condition1 and condition3: - atom.SetProp("FG", "carboxyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "carboxyl") - if ( - condition1 - and condition4 - and atom.GetProp("FG") - not in ["carbamate", "carbonate_ester"] - ): - atom.SetProp("FG", "ester") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "ester") - for O_neighbor in neighbor.GetNeighbors(): - O_neighbor.SetProp("FG", "ester") - - if num_H == 1 and not in_ring: - condition1, condition2 = False, False - cnt = 0 - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - and neighbor.GetTotalNumHs() == 1 - ): - condition1 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - and 
neighbor.GetTotalNumHs() == 0 - ): - condition2 = True - cnt += 1 - - if condition1 and condition2: - atom.SetProp("FG", "hemiacetal") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "hemiacetal") - if cnt == 2: - atom.SetProp("FG", "acetal") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "acetal") + O_neighbor.SetProp("FG", "ester") - if atom_num_neighbors == 4 and not in_ring: + if num_H == 1 and not in_ring: condition1, condition2 = False, False cnt = 0 for neighbor in atom_neighbors: @@ -382,7 +364,6 @@ def detect_functional_group(mol): # type: ignore == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 1 - and not neighbor.IsInRing() ): condition1 = True if ( @@ -393,24 +374,24 @@ def detect_functional_group(mol): # type: ignore == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0 - and not neighbor.IsInRing() ): condition2 = True cnt += 1 if condition1 and condition2: - atom.SetProp("FG", "hemiketal") + atom.SetProp("FG", "hemiacetal") for neighbor in atom_neighbors: if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "hemiketal") + neighbor.SetProp("FG", "hemiacetal") if cnt == 2: - atom.SetProp("FG", "ketal") + atom.SetProp("FG", "acetal") for neighbor in atom_neighbors: if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "ketal") + neighbor.SetProp("FG", "acetal") - if num_O == 3 and atom_num_neighbors == 4 and not in_ring: - n_C = 0 + if atom_num_neighbors == 4 and not in_ring: + condition1, condition2 = False, False + cnt = 0 for neighbor in atom_neighbors: if ( neighbor.GetSymbol() == "O" @@ -419,49 +400,10 @@ def detect_functional_group(mol): # type: ignore ).GetBondType() == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 - and neighbor.GetTotalNumHs() == 0 - ): - n_C += 1 - if n_C == 3: - atom.SetProp("FG", "orthoester") - for neighbor in atom_neighbors: - if 
neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "orthoester") - - if ( - num_O == 3 - and atom_num_neighbors == 3 - and charge == 0 - and not in_ring - ): - condition1 = False - n_O = 0 - for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() ): condition1 = True - if ( - mol.GetBondBetweenAtoms( - atom.GetIdx(), neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - and neighbor.GetTotalNumHs() == 0 - ): - n_O += 1 - if condition1 and n_O == 2: - atom.SetProp("FG", "carbonate_ester") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "carbonate_ester") - - if num_O == 4 and not in_ring: - n_C = 0 - for neighbor in atom_neighbors: if ( neighbor.GetSymbol() == "O" and mol.GetBondBetweenAtoms( @@ -470,1139 +412,846 @@ def detect_functional_group(mol): # type: ignore == Chem.BondType.SINGLE and neighbor.GetFormalCharge() == 0 and neighbor.GetTotalNumHs() == 0 - ): - n_C += 1 - if n_C == 4: - atom.SetProp("FG", "orthocarbonate_ester") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "orthocarbonate_ester") - - ########################### Groups containing nitrogen ########################### - #### Amidine #### - if num_N == 2 and atom_num_neighbors == 3: - condition1, condition2 = False, False - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "N" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 2 - and neighbor.GetFormalCharge() == 0 - and not neighbor.IsInRing() - ): - condition1 = True - if ( - neighbor.GetSymbol() == "N" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and len(neighbor.GetNeighbors()) == 3 - and neighbor.GetFormalCharge() 
== 0 and not neighbor.IsInRing() ): condition2 = True + cnt += 1 + if condition1 and condition2: - atom.SetProp("FG", "amidine") + atom.SetProp("FG", "hemiketal") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "hemiketal") + if cnt == 2: + atom.SetProp("FG", "ketal") for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "N": - neighbor.SetProp("FG", "amidine") + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "ketal") - if num_N == 1 and num_O == 2 and atom_num_neighbors == 3: - condition1, condition2, condition3 = False, False, False + if num_O == 3 and atom_num_neighbors == 4 and not in_ring: + n_C = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): + n_C += 1 + if n_C == 3: + atom.SetProp("FG", "orthoester") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - condition1 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - and len(neighbor.GetNeighbors()) == 2 - ): - condition2 = True - if ( - neighbor.GetSymbol() == "N" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - and len(neighbor.GetNeighbors()) == 3 - and not neighbor.IsInRing() - ): - condition3 = True - if condition1 and condition2 and condition3: - atom.SetProp("FG", "carbamate") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "carbamate") + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "orthoester") - if num_N == 1 and 
num_S == 1: - condition1, condition2 = False, False + if num_O == 3 and atom_num_neighbors == 3 and charge == 0 and not in_ring: + condition1 = False + n_O = 0 + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): + n_O += 1 + if condition1 and n_O == 2: + atom.SetProp("FG", "carbonate_ester") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "N" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 2 - and not neighbor.IsInRing() - ): - condition1 = True - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 1 - and neighbor.GetTotalNumHs() == 0 - and not neighbor.IsInRing() - ): - condition2 = True - if condition1 and condition2: - atom.SetProp("FG", "isothiocyanate") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "isothiocyanate") + neighbor.SetProp("FG", "carbonate_ester") - if num_S == 1 and atom_num_neighbors == 3: + if num_O == 4 and not in_ring: + n_C = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom.GetIdx(), neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 0 + ): + n_C += 1 + if n_C == 4: + atom.SetProp("FG", "orthocarbonate_ester") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and 
len(neighbor.GetNeighbors()) == 1 - and neighbor.GetTotalNumHs() == 0 - and not neighbor.IsInRing() - ): - atom.SetProp("FG", "thioketone") - neighbor.SetProp("FG", "thioketone") - - if num_S == 1 and num_H == 1 and atom_num_neighbors == 2: - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 1 - and neighbor.GetTotalNumHs() == 0 - and not neighbor.IsInRing() - ): - atom.SetProp("FG", "thial") - neighbor.SetProp("FG", "thial") - - if num_S == 1 and num_O == 1 and atom_num_neighbors == 3: - condition1, condition2 = False, False - condition3, condition4 = False, False - condition5, condition6 = False, False - condition7, condition8 = False, False - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.SINGLE - and len(neighbor.GetNeighbors()) == 1 - and neighbor.GetTotalNumHs() == 1 - and not neighbor.IsInRing() - ): - condition1 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and not neighbor.IsInRing() - ): - condition2 = True - - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 1 - and not neighbor.IsInRing() - ): - condition3 = True - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetTotalNumHs() == 0 - and not len(neighbor.GetNeighbors()) == 1 - ): - condition4 = True - - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.SINGLE - and len(neighbor.GetNeighbors()) == 2 - and 
neighbor.GetTotalNumHs() == 0 - and not neighbor.IsInRing() - ): - flag = True - for bond in neighbor.GetBonds(): - if bond.GetBondType() != Chem.BondType.SINGLE: - flag = False - if flag: - condition5 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and not neighbor.IsInRing() - ): - condition6 = True - - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.SINGLE - and len(neighbor.GetNeighbors()) == 2 - and neighbor.GetFormalCharge() == 0 - and not neighbor.IsInRing() - ): - condition7 = True - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetTotalNumHs() == 0 - and len(neighbor.GetNeighbors()) == 1 - and not neighbor.IsInRing() - ): - condition8 = True - - if condition1 and condition2: - atom.SetProp("FG", "carbothioic_S-acid") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["S", "O"]: - neighbor.SetProp("FG", "carbothioic_S-acid") - if condition3 and condition4: - atom.SetProp("FG", "carbothioic_O-acid") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["S", "O"]: - neighbor.SetProp("FG", "carbothioic_O-acid") - if condition5 and condition6: - atom.SetProp("FG", "thiolester") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["S", "O"]: - neighbor.SetProp("FG", "thiolester") - if condition7 and condition8: - atom.SetProp("FG", "thionoester") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["S", "O"]: - neighbor.SetProp("FG", "thionoester") - - if num_S == 2 and atom_num_neighbors == 3: - condition1, condition2, condition3 = False, False, False - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.SINGLE - 
and neighbor.GetTotalNumHs() == 1 - and len(neighbor.GetNeighbors()) == 1 - and not neighbor.IsInRing() - ): - condition1 = True - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetTotalNumHs() == 0 - and len(neighbor.GetNeighbors()) == 1 - and not neighbor.IsInRing() - ): - condition2 = True - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), atom_idx - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 0 - and len(neighbor.GetNeighbors()) == 2 - and not neighbor.IsInRing() - ): - flag = True - for bond in neighbor.GetBonds(): - if bond.GetBondType() != Chem.BondType.SINGLE: - flag = False - if flag: - condition3 = True - - if condition1 and condition2: - atom.SetProp("FG", "carbodithioic_acid") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "S": - neighbor.SetProp("FG", "carbodithioic_acid") - if condition3 and condition2: - atom.SetProp("FG", "carbodithio") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "S": - neighbor.SetProp("FG", "carbodithio") - - if num_X == 3 and charge == 0 and atom_num_neighbors == 4: - num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "F": - num_F += 1 - if neighbor.GetSymbol() == "Cl": - num_Cl += 1 - if neighbor.GetSymbol() == "Br": - num_Br += 1 - if neighbor.GetSymbol() == "I": - num_I += 1 - if num_F == 3: - atom.SetProp("FG", "trifluoromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "F": - neighbor.SetProp("FG", "trifluoromethyl") - if num_F == 2 and num_Cl == 1: - atom.SetProp("FG", "difluorochloromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["F", "Cl"]: - neighbor.SetProp("FG", "difluorochloromethyl") - if num_F == 2 and num_Br == 1: - atom.SetProp("FG", "bromodifluoromethyl") - for neighbor in atom_neighbors: - if 
neighbor.GetSymbol() in ["F", "Br"]: - neighbor.SetProp("FG", "bromodifluoromethyl") - - if num_Cl == 3: - atom.SetProp("FG", "trichloromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "Cl": - neighbor.SetProp("FG", "trichloromethyl") - if num_Cl == 2 and num_Br == 1: - atom.SetProp("FG", "bromodichloromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["Cl", "Br"]: - neighbor.SetProp("FG", "bromodichloromethyl") - - if num_Br == 3: - atom.SetProp("FG", "tribromomethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "Br": - neighbor.SetProp("FG", "tribromomethyl") - if num_Br == 2 and num_F == 1: - atom.SetProp("FG", "dibromofluoromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["F", "Br"]: - neighbor.SetProp("FG", "dibromofluoromethyl") - - if num_I == 3: - atom.SetProp("FG", "triiodomethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "I": - neighbor.SetProp("FG", "triiodomethyl") - - if ( - num_X == 2 - and charge == 0 - and atom_num_neighbors == 3 - and num_H == 1 - ): - num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "F": - num_F += 1 - if neighbor.GetSymbol() == "Cl": - num_Cl += 1 - if neighbor.GetSymbol() == "Br": - num_Br += 1 - if neighbor.GetSymbol() == "I": - num_I += 1 - - if num_F == 2: - atom.SetProp("FG", "difluoromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "F": - neighbor.SetProp("FG", "difluoromethyl") - if num_F == 1 and num_Cl == 1: - atom.SetProp("FG", "fluorochloromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["F", "Cl"]: - neighbor.SetProp("FG", "fluorochloromethyl") - - if num_Cl == 2: - atom.SetProp("FG", "dichloromethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "Cl": - neighbor.SetProp("FG", "dichloromethyl") - if num_Cl == 1 and num_Br == 1: - atom.SetProp("FG", "chlorobromomethyl") - for neighbor in 
atom_neighbors: - if neighbor.GetSymbol() in ["Cl", "Br"]: - neighbor.SetProp("FG", "chlorobromomethyl") - if num_Cl == 1 and num_I == 1: - atom.SetProp("FG", "chloroiodomethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["Cl", "I"]: - neighbor.SetProp("FG", "chloroiodomethyl") - - if num_Br == 2: - atom.SetProp("FG", "dibromomethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "Br": - neighbor.SetProp("FG", "dibromomethyl") - if num_Br == 1 and num_I == 1: - atom.SetProp("FG", "bromoiodomethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["Br", "I"]: - neighbor.SetProp("FG", "bromoiodomethyl") - - if num_I == 2: - atom.SetProp("FG", "diiodomethyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "I": - neighbor.SetProp("FG", "diiodomethyl") - - if ( - (atom_num_neighbors == 2 or atom_num_neighbors == 1) - and not in_ring - and atom.GetProp("FG") == "" - ): - bonds = atom.GetBonds() - ns, nd, nt = 0, 0, 0 - for bond in bonds: - if bond.GetBondType() == Chem.BondType.SINGLE: - ns += 1 - elif bond.GetBondType() == Chem.BondType.DOUBLE: - nd += 1 - else: - nt += 1 - if ns >= 1 and nd == 0 and nt == 0: - atom.SetProp("FG", "alkyl") - if nd >= 1: - atom.SetProp("FG", "alkene") - if nt == 1: - atom.SetProp("FG", "alkyne") - - elif ( - atom_symbol == "O" and not in_ring and charge == 0 and num_H == 0 - ): # Carboxylic anhydride [C(CO)O(CO)C] - num_C = 0 - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["C", "*"]: - num_C += 1 - if num_C == 2: - cnt = 0 - for neighbor in atom_neighbors: - for C_neighbor in neighbor.GetNeighbors(): - if ( - C_neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), C_neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 3 - ): - cnt += 1 - if cnt == 2: - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "carboxylic_anhydride") - for C_neighbor in neighbor.GetNeighbors(): - if 
C_neighbor.GetSymbol() == "O": - C_neighbor.SetProp("FG", "carboxylic_anhydride") - - elif atom_symbol == "N": # and atom.GetProp('FG') == '': - num_C, num_O, num_N = 0, 0, 0 - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["C", "*"]: - num_C += 1 - if neighbor.GetSymbol() == "O": - num_O += 1 - if neighbor.GetSymbol() == "N": - num_N += 1 - - #### Amines #### - if ( - charge == 0 - and num_H == 2 - and atom_num_neighbors == 1 - and atom.GetProp("FG") != "hydrazone" - ): # Primary amine [RNH2] - atom.SetProp("FG", "primary_amine") + neighbor.SetProp("FG", "orthocarbonate_ester") - if ( - charge == 0 and num_H == 1 and atom_num_neighbors == 2 - ): # Secondary amine [R'R"NH] - atom.SetProp("FG", "secondary_amine") - - if ( - charge == 0 - and atom_num_neighbors == 3 - and atom.GetProp("FG") != "carbamate" - ): - cnt = 0 - C_idx = [] - for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["C", "*"]: - for C_neighbor in neighbor.GetNeighbors(): - if ( - C_neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - neighbor.GetIdx(), C_neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 3 - and neighbor.GetFormalCharge() == 0 - and atom.GetProp("FG") != "imide" - ): - atom.SetProp("FG", "amide") - neighbor.SetProp("FG", "amide") - C_neighbor.SetProp("FG", "amide") - cnt += 1 - C_idx.append(neighbor.GetIdx()) - - if cnt == 2: - for neighbor in atom_neighbors: - if neighbor.GetIdx() in C_idx: - for C_neighbor in neighbor.GetNeighbors(): - if C_neighbor.GetSymbol() in ["O", "N"]: - neighbor.SetProp("FG", "imide") - C_neighbor.SetProp("FG", "imide") - - if atom.GetProp("FG") not in [ - "imide", - "amide", - "amidine", - "carbamate", - ]: # Tertiary amine [R3N] - atom.SetProp("FG", "tertiary_amine") - - if charge == 1 and atom_num_neighbors == 4: # 4° ammonium ion [R3N] - atom.SetProp("FG", "4_ammonium_ion") - - if ( - charge == 0 - and num_C == 1 - and num_N == 1 - and num_H == 0 - and 
atom_num_neighbors == 2 - ): # Hydrazone [R'R"CN2H2] - condition1, condition2 = False, False - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() in ["C", "*"] - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 3 - and neighbor.GetFormalCharge() == 0 - ): - condition1 = True - if ( - neighbor.GetSymbol() == "N" - and neighbor.GetTotalNumHs() == 2 - and neighbor.GetFormalCharge() == 0 - ): - condition2 = True - if condition1 and condition2: - atom.SetProp("FG", "hydrazone") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "hydrazone") - - #### Imine #### - if ( - charge == 0 - and num_C == 1 - and num_H == 1 - and num_N == 0 - and atom_num_neighbors == 1 - ): # Primary ketimine [RC(=NH)R'] - for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 3 - and neighbor.GetFormalCharge() == 0 - ): - atom.SetProp("FG", "primary_ketimine") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "primary_ketimine") - - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 2 - and neighbor.GetTotalNumHs() == 1 - and neighbor.GetFormalCharge() == 0 - ): - atom.SetProp("FG", "primary_aldimine") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "primary_aldimine") - - if ( - charge == 0 - and atom_num_neighbors == 1 - and atom.GetProp("FG") not in ["thiocyanate", "cyanate"] - ): # Nitrile - for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.TRIPLE - ): - atom.SetProp("FG", "nitrile") - - if ( - charge == 0 - and num_C >= 1 - and atom_num_neighbors == 2 - and atom.GetProp("FG") != "hydrazone" - ): # Secondary ketimine [RC(=NR'')R'] - for neighbor in atom_neighbors: - if 
( - neighbor.GetSymbol() in ["C", "*"] - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 3 - and neighbor.GetFormalCharge() == 0 - ): - atom.SetProp("FG", "secondary_ketimine") - for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - ): - neighbor.SetProp("FG", "secondary_ketimine") - - if ( - neighbor.GetSymbol() in ["C", "*"] - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and len(neighbor.GetNeighbors()) == 2 - and neighbor.GetFormalCharge() == 0 - and neighbor.GetTotalNumHs() == 1 - ): - atom.SetProp("FG", "secondary_aldimine") - for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - ): - neighbor.SetProp("FG", "secondary_aldimine") - - if ( - charge == 1 and num_N == 2 and atom_num_neighbors == 2 - ): # Azide [RN3] - condition1, condition2 = False, False - for neighbor in atom_neighbors: - if ( - neighbor.GetFormalCharge() == 0 - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - ): - condition1 = True - if ( - neighbor.GetFormalCharge() == -1 - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - ): - condition2 = True - if condition1 and condition2 and not in_ring: - atom.SetProp("FG", "azide") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "azide") - - if ( - charge == 0 - and num_N == 1 - and atom_num_neighbors == 2 - and not in_ring - ): # Azo [RN2R'] - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "N" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - atom.SetProp("FG", "azo") - neighbor.SetProp("FG", 
"azo") - break - - if ( - charge == 1 and num_O == 3 and atom_num_neighbors == 3 - ): # Nitrate [RONO2] - condition1, condition2, condition3 = False, False, False - for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - condition1 = True - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == -1 - ): - condition2 = True - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == 0 - ): - condition3 = True - - if condition1 and condition2 and condition3 and not in_ring: - atom.SetProp("FG", "nitrate") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "nitrate") - - if charge == 1 and num_C >= 1 and atom_num_neighbors == 2: # Isonitrile - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() in ["C", "*"] - and neighbor.GetFormalCharge() == -1 - and len(neighbor.GetNeighbors()) == 1 - ): - atom.SetProp("FG", "isonitrile") - neighbor.SetProp("FG", "isonitrile") - - if ( - charge == 0 - and num_O == 2 - and atom_num_neighbors == 2 - and not in_ring - ): # Nitrite - for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and len(neighbor.GetNeighbors()) == 2 - ): - atom.SetProp("FG", "nitrosooxy") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "nitrosooxy") - - if ( - charge == 1 - and num_O == 2 - and atom_num_neighbors == 3 - and not in_ring - ): # Nitro compound - condition1, condition2 = False, False - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - condition1 = True - if ( - 
mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetFormalCharge() == -1 - ): - condition2 = True - if condition1 and condition2 and not in_ring: - atom.SetProp("FG", "nitro") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "nitro") + ########################### Groups containing nitrogen ########################### + #### Amidine #### + if num_N == 2 and atom_num_neighbors == 3: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "amidine") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "N": + neighbor.SetProp("FG", "amidine") - if ( - charge == 0 - and num_O == 1 - and atom_num_neighbors == 2 - and not in_ring - ): + if num_N == 1 and num_O == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 2 + ): + condition2 = True + if ( + 
neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 3 + and not neighbor.IsInRing() + ): + condition3 = True + if condition1 and condition2 and condition3: + atom.SetProp("FG", "carbamate") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - ): # Nitroso compound - atom.SetProp("FG", "nitroso") - neighbor.SetProp("FG", "nitroso") + neighbor.SetProp("FG", "carbamate") - if ( - charge == 0 - and num_O == 1 - and num_C == 1 - and atom_num_neighbors == 2 - ): - condition1, condition2, condition3 = False, False, False + if num_N == 1 and num_S == 1: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "isothiocyanate") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 1 - ): - condition1 = True - if ( - neighbor.GetSymbol() in ["C", "*"] - and neighbor.GetTotalNumHs() == 1 - and neighbor.GetFormalCharge() == 0 - ): - condition2 = True - if ( - neighbor.GetSymbol() in ["C", "*"] - and neighbor.GetTotalNumHs() == 0 - and neighbor.GetFormalCharge() == 0 - and 
len(neighbor.GetNeighbors()) == 3 - ): - condition3 = True + neighbor.SetProp("FG", "isothiocyanate") - if condition1 and condition2 and not in_ring: - atom.SetProp("FG", "aldoxime") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "aldoxime") - if condition1 and condition3 and not in_ring: - atom.SetProp("FG", "ketoxime") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "ketoxime") + if num_S == 1 and atom_num_neighbors == 3: + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + atom.SetProp("FG", "thioketone") + neighbor.SetProp("FG", "thioketone") - ########################### Groups containing sulfur ########################### - elif atom_symbol == "S" and charge == 0: - num_C, num_S, num_O = 0, 0, 0 + if num_S == 1 and num_H == 1 and atom_num_neighbors == 2: for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["C", "*"]: - num_C += 1 - if neighbor.GetSymbol() == "S": - num_S += 1 - if neighbor.GetSymbol() == "O": - num_O += 1 + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + atom.SetProp("FG", "thial") + neighbor.SetProp("FG", "thial") + + if num_S == 1 and num_O == 1 and atom_num_neighbors == 3: + condition1, condition2 = False, False + condition3, condition4 = False, False + condition5, condition6 = False, False + condition7, condition8 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 1 + and neighbor.GetTotalNumHs() == 
1 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and not neighbor.IsInRing() + ): + condition2 = True - if ( - num_H == 1 - and atom_num_neighbors == 1 - and atom.GetProp("FG") - not in ["carbothioic_S-acid", "carbodithioic_acid"] - ): - neighbor = atom_neighbors[0] if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx ).GetBondType() == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and not neighbor.IsInRing() + ): + condition3 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and not len(neighbor.GetNeighbors()) == 1 ): - atom.SetProp("FG", "sulfhydryl") + condition4 = True - if ( - num_H == 0 - and atom_num_neighbors == 2 - and atom.GetProp("FG") not in ["sulfhydrylester", "carbodithio"] - ): - cnt = 0 + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetTotalNumHs() == 0 + and not neighbor.IsInRing() + ): + flag = True + for bond in neighbor.GetBonds(): + if bond.GetBondType() != Chem.BondType.SINGLE: + flag = False + if flag: + condition5 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and not neighbor.IsInRing() + ): + condition6 = True + + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and not neighbor.IsInRing() + ): + condition7 = True + if ( + 
neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): + condition8 = True + + if condition1 and condition2: + atom.SetProp("FG", "carbothioic_S-acid") for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - ): - cnt += 1 - if cnt == 2: - atom.SetProp("FG", "sulfide") + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "carbothioic_S-acid") + if condition3 and condition4: + atom.SetProp("FG", "carbothioic_O-acid") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "carbothioic_O-acid") + if condition5 and condition6: + atom.SetProp("FG", "thiolester") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "thiolester") + if condition7 and condition8: + atom.SetProp("FG", "thionoester") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["S", "O"]: + neighbor.SetProp("FG", "thionoester") - if num_H == 0 and num_S == 1 and atom_num_neighbors == 2: - condition1, condition2 = False, False + if num_S == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): + condition1 = True + if ( + neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 1 + and not neighbor.IsInRing() + ): + condition2 = True + if ( + 
neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + neighbor.GetIdx(), atom_idx + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and len(neighbor.GetNeighbors()) == 2 + and not neighbor.IsInRing() + ): + flag = True + for bond in neighbor.GetBonds(): + if bond.GetBondType() != Chem.BondType.SINGLE: + flag = False + if flag: + condition3 = True + + if condition1 and condition2: + atom.SetProp("FG", "carbodithioic_acid") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "S" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and len(neighbor.GetNeighbors()) == 2 - ): - condition1 = True - if ( - neighbor.GetSymbol() != "S" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - ): - condition2 = True - if condition1 and condition2: - atom.SetProp("FG", "disulfide") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "S": - neighbor.SetProp("FG", "disulfide") + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "carbodithioic_acid") + if condition3 and condition2: + atom.SetProp("FG", "carbodithio") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "carbodithio") - if num_H == 0 and num_O >= 1 and atom_num_neighbors == 3: - condition = False - cnt = 0 + if num_X == 3 and charge == 0 and atom_num_neighbors == 4: + num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "F": + num_F += 1 + if neighbor.GetSymbol() == "Cl": + num_Cl += 1 + if neighbor.GetSymbol() == "Br": + num_Br += 1 + if neighbor.GetSymbol() == "I": + num_I += 1 + if num_F == 3: + atom.SetProp("FG", "trifluoromethyl") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - 
condition = True - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - ): - cnt += 1 - if condition and cnt == 2: - atom.SetProp("FG", "sulfinyl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "sulfinyl") + if neighbor.GetSymbol() == "F": + neighbor.SetProp("FG", "trifluoromethyl") + if num_F == 2 and num_Cl == 1: + atom.SetProp("FG", "difluorochloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Cl"]: + neighbor.SetProp("FG", "difluorochloromethyl") + if num_F == 2 and num_Br == 1: + atom.SetProp("FG", "bromodifluoromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Br"]: + neighbor.SetProp("FG", "bromodifluoromethyl") - if num_H == 0 and num_O >= 2 and atom_num_neighbors == 4: - cnt1 = 0 + if num_Cl == 3: + atom.SetProp("FG", "trichloromethyl") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - ): - cnt1 += 1 - if cnt1 == 2: - atom.SetProp("FG", "sulfonyl") - for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - ): - neighbor.SetProp("FG", "sulfonyl") + if neighbor.GetSymbol() == "Cl": + neighbor.SetProp("FG", "trichloromethyl") + if num_Cl == 2 and num_Br == 1: + atom.SetProp("FG", "bromodichloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Cl", "Br"]: + neighbor.SetProp("FG", "bromodichloromethyl") - if num_H == 0 and num_O == 2 and atom_num_neighbors == 3: - condition1, condition2, condition3 = False, False, False + if num_Br == 3: + atom.SetProp("FG", "tribromomethyl") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == 
Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - condition1 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 1 - and neighbor.GetFormalCharge() == 0 - ): - condition2 = True - if ( - neighbor.GetSymbol() != "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - ): - condition3 = True - if condition1 and condition2 and condition3 and not in_ring: - atom.SetProp("FG", "sulfino") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "sulfino") + if neighbor.GetSymbol() == "Br": + neighbor.SetProp("FG", "tribromomethyl") + if num_Br == 2 and num_F == 1: + atom.SetProp("FG", "dibromofluoromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Br"]: + neighbor.SetProp("FG", "dibromofluoromethyl") - if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: - condition1, condition2 = False, False - cnt = 0 + if num_I == 3: + atom.SetProp("FG", "triiodomethyl") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - cnt += 1 - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 1 - and neighbor.GetFormalCharge() == 0 - ): - condition1 = True - if ( - neighbor.GetSymbol() != "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - ): - condition2 = True - if condition1 and condition2 and cnt == 2 and not in_ring: - atom.SetProp("FG", "sulfonic_acid") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "sulfonic_acid") + if 
neighbor.GetSymbol() == "I": + neighbor.SetProp("FG", "triiodomethyl") + + if num_X == 2 and charge == 0 and atom_num_neighbors == 3 and num_H == 1: + num_F, num_Cl, num_Br, num_I = 0, 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "F": + num_F += 1 + if neighbor.GetSymbol() == "Cl": + num_Cl += 1 + if neighbor.GetSymbol() == "Br": + num_Br += 1 + if neighbor.GetSymbol() == "I": + num_I += 1 - if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: - condition1, condition2 = False, False - cnt = 0 + if num_F == 2: + atom.SetProp("FG", "difluoromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "F": + neighbor.SetProp("FG", "difluoromethyl") + if num_F == 1 and num_Cl == 1: + atom.SetProp("FG", "fluorochloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["F", "Cl"]: + neighbor.SetProp("FG", "fluorochloromethyl") + + if num_Cl == 2: + atom.SetProp("FG", "dichloromethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Cl": + neighbor.SetProp("FG", "dichloromethyl") + if num_Cl == 1 and num_Br == 1: + atom.SetProp("FG", "chlorobromomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Cl", "Br"]: + neighbor.SetProp("FG", "chlorobromomethyl") + if num_Cl == 1 and num_I == 1: + atom.SetProp("FG", "chloroiodomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Cl", "I"]: + neighbor.SetProp("FG", "chloroiodomethyl") + + if num_Br == 2: + atom.SetProp("FG", "dibromomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Br": + neighbor.SetProp("FG", "dibromomethyl") + if num_Br == 1 and num_I == 1: + atom.SetProp("FG", "bromoiodomethyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["Br", "I"]: + neighbor.SetProp("FG", "bromoiodomethyl") + + if num_I == 2: + atom.SetProp("FG", "diiodomethyl") for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "I": + neighbor.SetProp("FG", "diiodomethyl") + + if ( + 
(atom_num_neighbors == 2 or atom_num_neighbors == 1) + and not in_ring + and atom.GetProp("FG") == "" + ): + bonds = atom.GetBonds() + ns, nd, nt = 0, 0, 0 + for bond in bonds: + if bond.GetBondType() == Chem.BondType.SINGLE: + ns += 1 + elif bond.GetBondType() == Chem.BondType.DOUBLE: + nd += 1 + else: + nt += 1 + if ns >= 1 and nd == 0 and nt == 0: + atom.SetProp("FG", "alkyl") + if nd >= 1: + atom.SetProp("FG", "alkene") + if nt == 1: + atom.SetProp("FG", "alkyne") + + elif ( + atom_symbol == "O" and not in_ring and charge == 0 and num_H == 0 + ): # Carboxylic anhydride [C(CO)O(CO)C] + num_C = 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if num_C == 2: + cnt = 0 + for neighbor in atom_neighbors: + for C_neighbor in neighbor.GetNeighbors(): if ( - neighbor.GetSymbol() == "O" + C_neighbor.GetSymbol() == "O" and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() + neighbor.GetIdx(), C_neighbor.GetIdx() ).GetBondType() == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 3 ): cnt += 1 - if ( - neighbor.GetSymbol() != "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - ): - condition1 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 0 - and neighbor.GetFormalCharge() == 0 - ): - condition2 = True - if condition1 and condition2 and cnt == 2: - atom.SetProp("FG", "sulfonate_ester") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "sulfonate_ester") - - if num_H == 0 and atom_num_neighbors == 2: + if cnt == 2: for neighbor in atom_neighbors: + neighbor.SetProp("FG", "carboxylic_anhydride") + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() == "O": + C_neighbor.SetProp("FG", "carboxylic_anhydride") + + elif atom_symbol 
== "N": # and atom.GetProp('FG') == '': + num_C, num_O, num_N = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + if neighbor.GetSymbol() == "N": + num_N += 1 + + #### Amines #### + if ( + charge == 0 + and num_H == 2 + and atom_num_neighbors == 1 + and atom.GetProp("FG") != "hydrazone" + ): # Primary amine [RNH2] + atom.SetProp("FG", "primary_amine") + + if ( + charge == 0 and num_H == 1 and atom_num_neighbors == 2 + ): # Secondary amine [R'R"NH] + atom.SetProp("FG", "secondary_amine") + + if ( + charge == 0 + and atom_num_neighbors == 3 + and atom.GetProp("FG") != "carbamate" + ): + cnt = 0 + C_idx = [] + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: for C_neighbor in neighbor.GetNeighbors(): if ( - C_neighbor.GetSymbol() == "N" + C_neighbor.GetSymbol() == "O" and mol.GetBondBetweenAtoms( - C_neighbor.GetIdx(), neighbor.GetIdx() + neighbor.GetIdx(), C_neighbor.GetIdx() ).GetBondType() - == Chem.BondType.TRIPLE - and not in_ring + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + and atom.GetProp("FG") != "imide" ): - atom.SetProp("FG", "thiocyanate") - neighbor.SetProp("FG", "thiocyanate") - C_neighbor.SetProp("FG", "thiocyanate") + atom.SetProp("FG", "amide") + neighbor.SetProp("FG", "amide") + C_neighbor.SetProp("FG", "amide") + cnt += 1 + C_idx.append(neighbor.GetIdx()) + + if cnt == 2: + for neighbor in atom_neighbors: + if neighbor.GetIdx() in C_idx: + for C_neighbor in neighbor.GetNeighbors(): + if C_neighbor.GetSymbol() in ["O", "N"]: + neighbor.SetProp("FG", "imide") + C_neighbor.SetProp("FG", "imide") + + if atom.GetProp("FG") not in [ + "imide", + "amide", + "amidine", + "carbamate", + ]: # Tertiary amine [R3N] + atom.SetProp("FG", "tertiary_amine") + + if charge == 1 and atom_num_neighbors == 4: # 4° ammonium ion [R3N] + atom.SetProp("FG", "4_ammonium_ion") - 
########################### Groups containing phosphorus ########################### - elif atom_symbol == "P" and not in_ring and charge == 0: - num_C, num_O = 0, 0 + if ( + charge == 0 + and num_C == 1 + and num_N == 1 + and num_H == 0 + and atom_num_neighbors == 2 + ): # Hydrazone [R'R"CN2H2] + condition1, condition2 = False, False for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["C", "*"]: - num_C += 1 - if neighbor.GetSymbol() == "O": - num_O += 1 + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "N" + and neighbor.GetTotalNumHs() == 2 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "hydrazone") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "hydrazone") - if atom_num_neighbors == 3: - cnt = 0 + #### Imine #### + if ( + charge == 0 + and num_C == 1 + and num_H == 1 + and num_N == 0 + and atom_num_neighbors == 1 + ): # Primary ketimine [RC(=NH)R'] + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "primary_ketimine") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "primary_ketimine") + + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "primary_aldimine") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "primary_aldimine") + + if ( + charge == 0 + and atom_num_neighbors == 1 + and atom.GetProp("FG") not in ["thiocyanate", "cyanate"] + 
): # Nitrile + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.TRIPLE + ): + atom.SetProp("FG", "nitrile") + + if ( + charge == 0 + and num_C >= 1 + and atom_num_neighbors == 2 + and atom.GetProp("FG") != "hydrazone" + ): # Secondary ketimine [RC(=NR'')R'] + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 3 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "secondary_ketimine") + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "secondary_ketimine") + + if ( + neighbor.GetSymbol() in ["C", "*"] + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 1 + ): + atom.SetProp("FG", "secondary_aldimine") + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + neighbor.SetProp("FG", "secondary_aldimine") + + if charge == 1 and num_N == 2 and atom_num_neighbors == 2: # Azide [RN3] + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetFormalCharge() == 0 + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition1 = True + if ( + neighbor.GetFormalCharge() == -1 + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition2 = True + if condition1 and condition2 and not in_ring: + atom.SetProp("FG", "azide") for neighbor in atom_neighbors: - if ( - mol.GetBondBetweenAtoms( - 
atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - ): - cnt += 1 - if cnt == 3: - atom.SetProp("FG", "phosphino") + neighbor.SetProp("FG", "azide") - if num_O == 3 and atom_num_neighbors == 4: - condition1, condition2 = False, False - cnt = 0 + if ( + charge == 0 and num_N == 1 and atom_num_neighbors == 2 and not in_ring + ): # Azo [RN2R'] + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "N" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "azo") + neighbor.SetProp("FG", "azo") + break + + if ( + charge == 1 and num_O == 3 and atom_num_neighbors == 3 + ): # Nitrate [RONO2] + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 + ): + condition2 = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == 0 + ): + condition3 = True + + if condition1 and condition2 and condition3 and not in_ring: + atom.SetProp("FG", "nitrate") for neighbor in atom_neighbors: + neighbor.SetProp("FG", "nitrate") + + if charge == 1 and num_C >= 1 and atom_num_neighbors == 2: # Isonitrile + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetFormalCharge() == -1 + and len(neighbor.GetNeighbors()) == 1 + ): + atom.SetProp("FG", "isonitrile") + neighbor.SetProp("FG", "isonitrile") + + if ( + charge == 0 and num_O == 2 and atom_num_neighbors == 2 and not in_ring + ): # Nitrite + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + 
atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + ): + atom.SetProp("FG", "nitrosooxy") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "nitrosooxy") + + if ( + charge == 1 and num_O == 2 and atom_num_neighbors == 3 and not in_ring + ): # Nitro compound + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( + mol.GetBondBetweenAtoms( atom_idx, neighbor.GetIdx() ).GetBondType() == Chem.BondType.DOUBLE @@ -1610,32 +1259,172 @@ def detect_functional_group(mol): # type: ignore ): condition1 = True if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 1 - and neighbor.GetFormalCharge() == 0 - ): - cnt += 1 - if ( - neighbor.GetSymbol() != "O" - and mol.GetBondBetweenAtoms( + mol.GetBondBetweenAtoms( atom_idx, neighbor.GetIdx() ).GetBondType() == Chem.BondType.SINGLE + and neighbor.GetFormalCharge() == -1 ): condition2 = True - if condition1 and condition2 and cnt == 2: - atom.SetProp("FG", "phosphono") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "phosphono") + if condition1 and condition2 and not in_ring: + atom.SetProp("FG", "nitro") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "nitro") + + if charge == 0 and num_O == 1 and atom_num_neighbors == 2 and not in_ring: + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): # Nitroso compound + atom.SetProp("FG", "nitroso") + neighbor.SetProp("FG", "nitroso") + + if charge == 0 and num_O == 1 and num_C == 1 and atom_num_neighbors == 2: + condition1, condition2, condition3 = False, False, False + 
for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + ): + condition1 = True + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 3 + ): + condition3 = True + + if condition1 and condition2 and not in_ring: + atom.SetProp("FG", "aldoxime") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "aldoxime") + if condition1 and condition3 and not in_ring: + atom.SetProp("FG", "ketoxime") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "ketoxime") + + ########################### Groups containing sulfur ########################### + elif atom_symbol == "S" and charge == 0: + num_C, num_S, num_O = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "S": + num_S += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + + if ( + num_H == 1 + and atom_num_neighbors == 1 + and atom.GetProp("FG") + not in ["carbothioic_S-acid", "carbodithioic_acid"] + ): + neighbor = atom_neighbors[0] + if ( + mol.GetBondBetweenAtoms(atom_idx, neighbor.GetIdx()).GetBondType() + == Chem.BondType.SINGLE + ): + atom.SetProp("FG", "sulfhydryl") + + if ( + num_H == 0 + and atom_num_neighbors == 2 + and atom.GetProp("FG") not in ["sulfhydrylester", "carbodithio"] + ): + cnt = 0 + for neighbor in atom_neighbors: + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if cnt == 2: + atom.SetProp("FG", "sulfide") + + if num_H == 0 and num_S == 1 and atom_num_neighbors == 2: + condition1, condition2 = False, False + for neighbor in atom_neighbors: + if ( 
+ neighbor.GetSymbol() == "S" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and len(neighbor.GetNeighbors()) == 2 + ): + condition1 = True + if ( + neighbor.GetSymbol() != "S" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + if condition1 and condition2: + atom.SetProp("FG", "disulfide") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "S": + neighbor.SetProp("FG", "disulfide") + + if num_H == 0 and num_O >= 1 and atom_num_neighbors == 3: + condition = False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if condition and cnt == 2: + atom.SetProp("FG", "sulfinyl") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfinyl") - if num_O == 4 and atom_num_neighbors == 4: - condition1 = False - cnt1, cnt2 = 0, 0 + if num_H == 0 and num_O >= 2 and atom_num_neighbors == 4: + cnt1 = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + cnt1 += 1 + if cnt1 == 2: + atom.SetProp("FG", "sulfonyl") for neighbor in atom_neighbors: if ( neighbor.GetSymbol() == "O" @@ -1644,202 +1433,393 @@ def detect_functional_group(mol): # type: ignore ).GetBondType() == Chem.BondType.DOUBLE ): - condition1 = True - if ( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 1 - and neighbor.GetFormalCharge() == 0 - ): - cnt1 += 1 - if 
( - neighbor.GetSymbol() == "O" - and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.SINGLE - and neighbor.GetTotalNumHs() == 0 - and neighbor.GetFormalCharge() == 0 - ): - cnt2 += 1 + neighbor.SetProp("FG", "sulfonyl") - if condition1 and cnt1 == 2 and cnt2 == 1: - atom.SetProp("FG", "phosphate") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "phosphate") - if condition1 and cnt1 == 1 and cnt2 == 2: - atom.SetProp("FG", "phosphodiester") - for neighbor in atom_neighbors: - neighbor.SetProp("FG", "phosphodiester") + if num_H == 0 and num_O == 2 and atom_num_neighbors == 3: + condition1, condition2, condition3 = False, False, False + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition3 = True + if condition1 and condition2 and condition3 and not in_ring: + atom.SetProp("FG", "sulfino") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfino") - if num_O == 1 and atom_num_neighbors == 4: - condition = False - cnt = 0 + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + cnt += 1 + if ( + neighbor.GetSymbol() == "O" + and 
mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + if condition1 and condition2 and cnt == 2 and not in_ring: + atom.SetProp("FG", "sulfonic_acid") + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfonic_acid") + + if num_H == 0 and num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + cnt += 1 + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + ): + condition2 = True + if condition1 and condition2 and cnt == 2: + atom.SetProp("FG", "sulfonate_ester") for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "sulfonate_ester") + + if num_H == 0 and atom_num_neighbors == 2: + for neighbor in atom_neighbors: + for C_neighbor in neighbor.GetNeighbors(): if ( - neighbor.GetSymbol() == "O" + C_neighbor.GetSymbol() == "N" and mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() - ).GetBondType() - == Chem.BondType.DOUBLE - and neighbor.GetFormalCharge() == 0 - ): - condition = True - if ( - mol.GetBondBetweenAtoms( - atom_idx, neighbor.GetIdx() + C_neighbor.GetIdx(), neighbor.GetIdx() ).GetBondType() - 
== Chem.BondType.SINGLE - ): - cnt += 1 - if condition and cnt == 3: - atom.SetProp("FG", "phosphoryl") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "phosphoryl") - - ########################### Groups containing boron ########################### - elif atom_symbol == "B" and not in_ring and charge == 0: - num_C, num_O = 0, 0 + == Chem.BondType.TRIPLE + and not in_ring + ): + atom.SetProp("FG", "thiocyanate") + neighbor.SetProp("FG", "thiocyanate") + C_neighbor.SetProp("FG", "thiocyanate") + + ########################### Groups containing phosphorus ########################### + elif atom_symbol == "P" and not in_ring and charge == 0: + num_C, num_O = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + + if atom_num_neighbors == 3: + cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() in ["C", "*"]: - num_C += 1 - if neighbor.GetSymbol() == "O": - num_O += 1 + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if cnt == 3: + atom.SetProp("FG", "phosphino") - if num_O == 2 and atom_num_neighbors == 3: - cnt1, cnt2 = 0, 0 + if num_O == 3 and atom_num_neighbors == 4: + condition1, condition2 = False, False + cnt = 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + cnt += 1 + if ( + neighbor.GetSymbol() != "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + condition2 = True + if 
condition1 and condition2 and cnt == 2: + atom.SetProp("FG", "phosphono") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and neighbor.GetTotalNumHs() == 1 - and neighbor.GetFormalCharge() == 0 - ): - cnt1 += 1 - if ( - neighbor.GetSymbol() == "O" - and neighbor.GetFormalCharge() == 0 - and len(neighbor.GetNeighbors()) == 2 - ): - cnt2 += 1 - if cnt1 == 2: - atom.SetProp("FG", "borono") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "borono") - if cnt2 == 2: - atom.SetProp("FG", "boronate") - for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - neighbor.SetProp("FG", "boronate") + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "phosphono") + + if num_O == 4 and atom_num_neighbors == 4: + condition1 = False + cnt1, cnt2 = 0, 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + ): + condition1 = True + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + cnt1 += 1 + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + and neighbor.GetTotalNumHs() == 0 + and neighbor.GetFormalCharge() == 0 + ): + cnt2 += 1 - if num_O == 1 and atom_num_neighbors == 3: + if condition1 and cnt1 == 2 and cnt2 == 1: + atom.SetProp("FG", "phosphate") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and neighbor.GetFormalCharge() == 0 - ): - if neighbor.GetTotalNumHs() == 1: - atom.SetProp("FG", "borino") - neighbor.SetProp("FG", "borino") - if len(neighbor.GetNeighbors()) == 2: - atom.SetProp("FG", "borinate") - neighbor.SetProp("FG", "borinate") - - ########################### Groups containing 
silicon ########################### - elif atom_symbol == "Si" and not in_ring and charge == 0: - num_O, num_Cl, num_C = 0, 0, 0 + neighbor.SetProp("FG", "phosphate") + if condition1 and cnt1 == 1 and cnt2 == 2: + atom.SetProp("FG", "phosphodiester") + for neighbor in atom_neighbors: + neighbor.SetProp("FG", "phosphodiester") + + if num_O == 1 and atom_num_neighbors == 4: + condition = False + cnt = 0 for neighbor in atom_neighbors: - if neighbor.GetSymbol() == "O": - num_O += 1 - if neighbor.GetSymbol() == "Cl": - num_Cl += 1 - if neighbor.GetSymbol() in ["C", "*"]: - num_C += 1 - if num_O == 1 and charge == 0 and atom_num_neighbors == 4: + if ( + neighbor.GetSymbol() == "O" + and mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.DOUBLE + and neighbor.GetFormalCharge() == 0 + ): + condition = True + if ( + mol.GetBondBetweenAtoms( + atom_idx, neighbor.GetIdx() + ).GetBondType() + == Chem.BondType.SINGLE + ): + cnt += 1 + if condition and cnt == 3: + atom.SetProp("FG", "phosphoryl") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "O" - and len(neighbor.GetNeighbors()) == 2 - and neighbor.GetFormalCharge() == 0 - ): - atom.SetProp("FG", "silyl_ether") - neighbor.SetProp("FG", "silyl_ether") - if num_Cl == 2 and charge == 0 and atom_num_neighbors == 4: + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "phosphoryl") + + ########################### Groups containing boron ########################### + elif atom_symbol == "B" and not in_ring and charge == 0: + num_C, num_O = 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if neighbor.GetSymbol() == "O": + num_O += 1 + + if num_O == 2 and atom_num_neighbors == 3: + cnt1, cnt2 = 0, 0 + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and neighbor.GetTotalNumHs() == 1 + and neighbor.GetFormalCharge() == 0 + ): + cnt1 += 1 + if ( + neighbor.GetSymbol() == "O" + and 
neighbor.GetFormalCharge() == 0 + and len(neighbor.GetNeighbors()) == 2 + ): + cnt2 += 1 + if cnt1 == 2: + atom.SetProp("FG", "borono") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() == "Cl" - and neighbor.GetFormalCharge() == 0 - ): - atom.SetProp("FG", "dichlorosilane") - neighbor.SetProp("FG", "dichlorosilane") - if ( - num_C >= 3 - and charge == 0 - and atom_num_neighbors == 4 - and atom.GetProp("FG") != "silyl_ether" - ): - cnt = 0 - C_idx = [] + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "borono") + if cnt2 == 2: + atom.SetProp("FG", "boronate") for neighbor in atom_neighbors: - if ( - neighbor.GetSymbol() in ["C", "*"] - and neighbor.GetFormalCharge() == 0 - and neighbor.GetTotalNumHs() == 3 - ): - cnt += 1 - C_idx.append(neighbor.GetIdx()) - if cnt == 3: - atom.SetProp("FG", "trimethylsilyl") - for idx in C_idx: - mol.GetAtomWithIdx(idx).SetProp("FG", "trimethylsilyl") - - ########################### Groups containing halogen ########################### - elif ( - atom_symbol == "F" - and not in_ring - and charge == 0 - and atom.GetProp("FG") == "" - ): - atom.SetProp("FG", "fluoro") - elif ( - atom_symbol == "Cl" - and not in_ring - and charge == 0 - and atom.GetProp("FG") == "" - ): - atom.SetProp("FG", "chloro") - elif ( - atom_symbol == "Br" - and not in_ring - and charge == 0 - and atom.GetProp("FG") == "" - ): - atom.SetProp("FG", "bromo") - elif ( - atom_symbol == "I" - and not in_ring + if neighbor.GetSymbol() == "O": + neighbor.SetProp("FG", "boronate") + + if num_O == 1 and atom_num_neighbors == 3: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O" and neighbor.GetFormalCharge() == 0: + if neighbor.GetTotalNumHs() == 1: + atom.SetProp("FG", "borino") + neighbor.SetProp("FG", "borino") + if len(neighbor.GetNeighbors()) == 2: + atom.SetProp("FG", "borinate") + neighbor.SetProp("FG", "borinate") + + ########################### Groups containing silicon ########################### + elif atom_symbol == "Si" 
and not in_ring and charge == 0: + num_O, num_Cl, num_C = 0, 0, 0 + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "O": + num_O += 1 + if neighbor.GetSymbol() == "Cl": + num_Cl += 1 + if neighbor.GetSymbol() in ["C", "*"]: + num_C += 1 + if num_O == 1 and charge == 0 and atom_num_neighbors == 4: + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() == "O" + and len(neighbor.GetNeighbors()) == 2 + and neighbor.GetFormalCharge() == 0 + ): + atom.SetProp("FG", "silyl_ether") + neighbor.SetProp("FG", "silyl_ether") + if num_Cl == 2 and charge == 0 and atom_num_neighbors == 4: + for neighbor in atom_neighbors: + if neighbor.GetSymbol() == "Cl" and neighbor.GetFormalCharge() == 0: + atom.SetProp("FG", "dichlorosilane") + neighbor.SetProp("FG", "dichlorosilane") + if ( + num_C >= 3 and charge == 0 - and atom.GetProp("FG") == "" + and atom_num_neighbors == 4 + and atom.GetProp("FG") != "silyl_ether" ): - atom.SetProp("FG", "iodo") - else: - pass - - ########################### Groups containing other elements ########################### - if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: - if charge == 0: - atom.SetProp("FG", atom_symbol) - else: - atom.SetProp("FG", f"{atom_symbol}[{charge}]") + cnt = 0 + C_idx = [] + for neighbor in atom_neighbors: + if ( + neighbor.GetSymbol() in ["C", "*"] + and neighbor.GetFormalCharge() == 0 + and neighbor.GetTotalNumHs() == 3 + ): + cnt += 1 + C_idx.append(neighbor.GetIdx()) + if cnt == 3: + atom.SetProp("FG", "trimethylsilyl") + for idx in C_idx: + mol.GetAtomWithIdx(idx).SetProp("FG", "trimethylsilyl") + + ########################### Groups containing halogen ########################### + elif ( + atom_symbol == "F" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "fluoro") + elif ( + atom_symbol == "Cl" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "chloro") + elif ( + atom_symbol == "Br" + and 
not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "bromo") + elif ( + atom_symbol == "I" + and not in_ring + and charge == 0 + and atom.GetProp("FG") == "" + ): + atom.SetProp("FG", "iodo") + else: + pass + + ########################### Groups containing other elements ########################### + if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: + if charge == 0: + atom.SetProp("FG", atom_symbol) else: - pass + atom.SetProp("FG", f"{atom_symbol}[{charge}]") + else: + pass - if atom_symbol == "*": - atom.SetProp("FG", "") + if atom_symbol == "*": + atom.SetProp("FG", "") def set_atom_map_num(mol): @@ -1868,8 +1848,10 @@ def find_atom_map(smiles): def get_structure(mol): + set_atom_map_num(mol) + fused_rings_groups: list[list[set[int]]] = set_ring_properties(mol) + detect_functional_group(mol) rings = mol.GetRingInfo().AtomRings() - fused_rings_groups: list[list[set[int]]] = _get_fused_rings_group(mol) splitting_bonds = set() for bond in mol.GetBonds(): @@ -1944,37 +1926,3 @@ def get_structure(mol): break return structure, BONDS - - -def _get_fused_rings_group(mol: Chem.Mol) -> list[list[set[int]]]: - rings = mol.GetRingInfo().AtomRings() - - fused_ring_groups = [] - visited = set() - - for i, ring1 in enumerate(rings): - if i in visited: - continue - fused_group = [set(ring1)] - visited.add(i) - for j, ring2 in enumerate(rings): - if j in visited or i == j: - continue - if len(set(ring1) & set(ring2)) >= 2: # At least 2 shared atoms - fused_group.append(set(ring2)) - visited.add(j) - if len(fused_group) > 1: - fused_ring_groups.append(fused_group) - - return fused_ring_groups - - -if __name__ == "__main__": - from rdkit.Chem import MolFromSmiles as s2m - - smiles = "CC(=O)OC1=CC=CC=C1C(=O)O" # Aspirin, CHEBI:15365 - acetylsalicylic acid - mol = s2m(smiles) - set_atom_map_num(mol) - detect_functional_group(mol) - get_structure(mol) - print(m2s(mol)) diff --git 
a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 42a86b1..d8131c1 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,4 +1,3 @@ -import textwrap from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple @@ -8,11 +7,7 @@ from torch_geometric.data import Data as GeomData from chebai_graph.preprocessing.collate import GraphCollator -from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import ( - detect_functional_group, - get_structure, - set_atom_map_num, -) +from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import get_structure from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.properties.constants import * @@ -354,8 +349,6 @@ def _construct_fg_to_atom_structure( """ # Rule-based algorithm to detect functional groups - set_atom_map_num(mol) - detect_functional_group(mol) structure, bonds = get_structure(mol) assert structure is not None, "Failed to detect functional groups." 
From 4651b5e18a00837b13965dba2209d234179cf9f2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Jun 2025 15:06:43 +0200 Subject: [PATCH 117/224] import data class in init --- chebai_graph/preprocessing/datasets/__init__.py | 13 +++++++++++++ configs/data/chebi50_augmented_gnn.yml | 2 +- configs/data/chebi50_graph.yml | 2 +- configs/data/chebi50_graph_properties.yml | 3 ++- configs/data/pubchem_graph.yml | 2 +- 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py index e69de29..43aed70 100644 --- a/chebai_graph/preprocessing/datasets/__init__.py +++ b/chebai_graph/preprocessing/datasets/__init__.py @@ -0,0 +1,13 @@ +from .chebi import ( + ChEBI50GraphData, + ChEBI50GraphFGAugmentorReader, + ChEBI50GraphProperties, +) +from .pubchem import PubChemGraphProperties + +__all__ = [ + "ChEBI50GraphFGAugmentorReader", + "ChEBI50GraphProperties", + "ChEBI50GraphData", + "PubChemGraphProperties", +] diff --git a/configs/data/chebi50_augmented_gnn.yml b/configs/data/chebi50_augmented_gnn.yml index 8eee34b..e6482ef 100644 --- a/configs/data/chebi50_augmented_gnn.yml +++ b/configs/data/chebi50_augmented_gnn.yml @@ -1,4 +1,4 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphFGAugmentorReader +class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphFGAugmentorReader init_args: properties: # Atom properties diff --git a/configs/data/chebi50_graph.yml b/configs/data/chebi50_graph.yml index 19c8753..12e8abd 100644 --- a/configs/data/chebi50_graph.yml +++ b/configs/data/chebi50_graph.yml @@ -1 +1 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphData +class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphData diff --git a/configs/data/chebi50_graph_properties.yml b/configs/data/chebi50_graph_properties.yml index c84ac3e..4c52209 100644 --- a/configs/data/chebi50_graph_properties.yml +++ 
b/configs/data/chebi50_graph_properties.yml @@ -1,5 +1,6 @@ -class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI50GraphProperties +class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphProperties init_args: + chebi_version: 231 properties: - chebai_graph.preprocessing.properties.AtomType - chebai_graph.preprocessing.properties.NumAtomBonds diff --git a/configs/data/pubchem_graph.yml b/configs/data/pubchem_graph.yml index c21f188..855f93a 100644 --- a/configs/data/pubchem_graph.yml +++ b/configs/data/pubchem_graph.yml @@ -1,4 +1,4 @@ -class_path: chebai_graph.preprocessing.datasets.pubchem.PubChemGraphProperties +class_path: chebai_graph.preprocessing.datasets.PubChemGraphProperties init_args: transform: class_path: chebai_graph.preprocessing.transform_unlabeled.MaskAtom # mask atoms / bonds From cd4ae43af0b41f01001acb02539d5dd7ce67d7a3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Jun 2025 15:24:18 +0200 Subject: [PATCH 118/224] fix fg None error --- .../preprocessing/fg_detection/fg_aware_rule_based.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index e628a52..a7748f8 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -43,6 +43,7 @@ def set_ring_properties(mol: Chem.Mol) -> list[list[set[int]]] | None: ######## SET RING PROP ######## # Get ring information ring_info = mol.GetRingInfo() + fused_rings_groups: list[list[set[int]]] = [] if ring_info.NumRings() > 0: # Get list of atom rings @@ -54,7 +55,6 @@ def set_ring_properties(mol: Chem.Mol) -> list[list[set[int]]] | None: # Set of rings to process remaining_rings = [set(ring) for ring in atom_rings] - fused_rings_groups: list[list[set[int]]] = [] # Process each ring block while remaining_rings: @@ -95,9 +95,11 @@ def 
detect_functional_group(mol: Chem.Mol): if mol is None: return - ######## SET FUNCTIONAL GROUP PROP ######## for atom in mol.GetAtoms(): atom.SetProp("FG", "") + + ######## SET FUNCTIONAL GROUP PROP ######## + for atom in mol.GetAtoms(): atom_symbol = atom.GetSymbol() atom_neighbors = atom.GetNeighbors() atom_num_neighbors = len(atom_neighbors) From 729fe09be8fe56d53e44d98304d3873be7b737e3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Jun 2025 15:26:10 +0200 Subject: [PATCH 119/224] add usual prop for augmented graph with default value --- .../preprocessing/properties/__init__.py | 40 ++- .../properties/augmented_properties.py | 234 +++++++++++------- .../preprocessing/properties/properties.py | 4 +- .../preprocessing/reader/augmented_reader.py | 3 +- configs/data/chebi50_augmented_baseline.yml | 14 ++ 5 files changed, 198 insertions(+), 97 deletions(-) create mode 100644 configs/data/chebi50_augmented_baseline.yml diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index fcee6d3..8e5228b 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -21,10 +21,22 @@ ) from .augmented_properties import ( - AtomNodeLevel, - AtomFunctionalGroup, - AtomRingSize, - BondLevel, + AugAtomNodeLevel, + AugAtomFunctionalGroup, + AugAtomRingSize, + AugBondLevel, + AugAtomType, + AugNumAtomBonds, + AugAtomCharge, + AugAtomChirality, + AugAtomHybridization, + AugAtomNumHs, + AugAtomAromaticity, + AugBondAromaticity, + AugBondType, + AugBondInRing, + AugMoleculeNumRings, + AugRDKit2DNormalized, ) # isort: on @@ -46,8 +58,20 @@ "MoleculeNumRings", "RDKit2DNormalized", # -------- Augmented Molecular Properties -------- - "AtomNodeLevel", - "AtomFunctionalGroup", - "AtomRingSize", - "BondLevel", + "AugAtomNodeLevel", + "AugAtomFunctionalGroup", + "AugAtomRingSize", + "AugBondLevel", + "AugAtomType", + "AugNumAtomBonds", + "AugAtomCharge", + 
"AugAtomChirality", + "AugAtomHybridization", + "AugAtomNumHs", + "AugAtomAromaticity", + "AugBondAromaticity", + "AugBondType", + "AugBondInRing", + "AugMoleculeNumRings", + "AugRDKit2DNormalized", ] diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 49aa319..c6c51c0 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -5,83 +5,12 @@ from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder +from . import properties as pr from .constants import * -from .properties import AtomProperty, BondProperty -class AugmentedBondProperty(BondProperty, ABC): - MAIN_KEY = "edges" - - def get_property_value(self, augmented_mol: Dict) -> List: - if self.MAIN_KEY not in augmented_mol: - raise KeyError( - f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" - ) - - missing_keys = EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys() - if missing_keys: - raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") - - atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE] - if not isinstance(atom_molecule, Chem.Mol): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' - ) - - prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] - - fg_atom_edges = augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE] - fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE] - fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_EDGE] - - if ( - not isinstance(fg_atom_edges, dict) - or not isinstance(fg_edges, dict) - or not isinstance(fg_graph_node_edges, dict) - ): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_EDGE}"]) ' - f"must be an instance of dict containing its properties" - ) - - # 
For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order - # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights - # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges.values()]) - prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) - prop_list.extend( - [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] - ) - - num_directed_edges = augmented_mol[self.MAIN_KEY][NUM_EDGES] // 2 - assert ( - len(prop_list) == num_directed_edges - ), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " - - return prop_list - - @abstractmethod - def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): - pass - - def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str): - value = self._get_bond_prop_value(bond, prop) - if not value: - # Every atom/node should have given value - raise ValueError(f"'{prop}' is set but empty.") - return value - - @staticmethod - def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): - if isinstance(bond, Chem.rdchem.Bond): - return bond.GetProp(prop) - elif isinstance(bond, dict): - return bond[prop] - else: - raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") - - -class AugmentedAtomProperty(AtomProperty, ABC): +# --------------------- Atom Properties ----------------------------- +class AugmentedAtomProperty(pr.AtomProperty, ABC): MAIN_KEY = "nodes" def get_property_value(self, augmented_mol: Dict): @@ -145,7 +74,7 @@ def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): ) -class AtomNodeLevel(AugmentedAtomProperty): +class AugAtomNodeLevel(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) 
@@ -153,23 +82,15 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): return self._check_modify_atom_prop_value(atom, NODE_LEVEL) -class AtomFunctionalGroup(AugmentedAtomProperty): +class AugAtomFunctionalGroup(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): return self._check_modify_atom_prop_value(atom, "FG") - def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): - if isinstance(atom, Chem.rdchem.Atom): - return atom.GetProp(prop) - elif isinstance(atom, dict): - return atom[prop] - else: - raise TypeError("Atom/Node should be of type `Chem.rdchem.Atom` or `dict`.") - -class AtomRingSize(AugmentedAtomProperty): +class AugAtomRingSize(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) @@ -186,9 +107,150 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str return 0 -class BondLevel(AugmentedBondProperty): +class AugNodeValueDefaulter(AugmentedAtomProperty, ABC): + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + if isinstance(atom, Chem.rdchem.Atom): + # Delegate to superclass method for atom + return super().get_atom_value(atom) + elif isinstance(atom, dict): + return 0 + else: + raise TypeError( + f"Expected Chem.rdchem.Atom or dict, got {type(atom).__name__}" + ) + + +class AugAtomType(AugNodeValueDefaulter, pr.AtomType): ... + + +class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): ... + + +class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): ... + + +class AugAtomChirality(AugNodeValueDefaulter, pr.AtomChirality): ... + + +class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): ... + + +class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): ... + + +class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): ... 
+ + +# --------------------- Bond Properties ------------------------------ +class AugmentedBondProperty(pr.BondProperty, ABC): + MAIN_KEY = "edges" + + def get_property_value(self, augmented_mol: Dict) -> List: + if self.MAIN_KEY not in augmented_mol: + raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + + missing_keys = EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") + + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE] + if not isinstance(atom_molecule, Chem.Mol): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' + ) + + prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] + + fg_atom_edges = augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE] + fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE] + fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_EDGE] + + if ( + not isinstance(fg_atom_edges, dict) + or not isinstance(fg_edges, dict) + or not isinstance(fg_graph_node_edges, dict) + ): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_EDGE}"]) ' + f"must be an instance of dict containing its properties" + ) + + # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order + # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights + # https://mail.python.org/pipermail/python-dev/2017-December/151283.html + prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges.values()]) + prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) + prop_list.extend( + [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] + ) + + num_directed_edges = augmented_mol[self.MAIN_KEY][NUM_EDGES] // 2 + assert ( + len(prop_list) == num_directed_edges + ), 
f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " + + return prop_list + + @abstractmethod + def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): + pass + + def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str): + value = self._get_bond_prop_value(bond, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value + + @staticmethod + def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): + if isinstance(bond, Chem.rdchem.Bond): + return bond.GetProp(prop) + elif isinstance(bond, dict): + return bond[prop] + else: + raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") + + +class AugBondLevel(AugmentedBondProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): return self._check_modify_bond_prop_value(bond, EDGE_LEVEL) + + +class AugBondValueDefaulter(AugmentedBondProperty, ABC): + def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): + if isinstance(bond, Chem.rdchem.Bond): + # Delegate to superclass method for bond + return super().get_bond_value(bond) + elif isinstance(bond, dict): + return 0 + else: + raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") + + +class AugBondAromaticity(AugBondValueDefaulter, pr.BondAromaticity): ... + + +class AugBondType(AugBondValueDefaulter, pr.BondType): ... + + +class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): ... 
+ + +# --------------------- Molecular Properties ------------------------------ +class AugmentedMolecularProperty(pr.MolecularProperty, ABC): + def get_property_value(self, augmented_mol: Dict) -> list: + mol: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] + assert isinstance(mol, Chem.Mol), "Molecule should be instance of `Chem.Mol`" + return super().get_property_value(mol) + + +class AugMoleculeNumRings(AugmentedMolecularProperty, pr.MoleculeNumRings): ... + + +class AugRDKit2DNormalized(AugmentedMolecularProperty, pr.RDKit2DNormalized): ... diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py index df7457d..0584eb2 100644 --- a/chebai_graph/preprocessing/properties/properties.py +++ b/chebai_graph/preprocessing/properties/properties.py @@ -32,8 +32,8 @@ def on_finish(self): def __str__(self): return self.name - def get_property_value(self, mol: Chem.rdchem.Mol | Dict): - raise NotImplementedError + @abstractmethod + def get_property_value(self, mol: Chem.rdchem.Mol | Dict): ... 
class AtomProperty(MolecularProperty, ABC): diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index d8131c1..616f1e8 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -396,7 +396,6 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes): ring_size = len(connected_atoms) fg_nodes[self._num_of_nodes] = { NODE_LEVEL: FG_NODE_LEVEL, - # E.g., Fused Ring has size "5-6", indicating size of each connected ring in fused ring "FG": f"RING_{ring_size}", "RING": ring_size, } @@ -405,6 +404,8 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes): ring_prop = atom.GetProp("RING") if not ring_prop: raise ValueError("Atom does not have a ring size set") + # TODO: discuss the case, should it be max or average + # An atom belonging to multiple rings in fused Ring has size "5-6", indicating size of each ring max_ring_size = max(list(map(int, ring_prop.split("-")))) atom.SetProp("FG", f"RING_{max_ring_size}") diff --git a/configs/data/chebi50_augmented_baseline.yml b/configs/data/chebi50_augmented_baseline.yml new file mode 100644 index 0000000..77c7346 --- /dev/null +++ b/configs/data/chebi50_augmented_baseline.yml @@ -0,0 +1,14 @@ +class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphFGAugmentorReader +init_args: + properties: + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds + - chebai_graph.preprocessing.properties.AugAtomCharge + - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomHybridization + - chebai_graph.preprocessing.properties.AugAtomNumHs + - chebai_graph.preprocessing.properties.AugBondType + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondAromaticity + #- chebai_graph.preprocessing.properties.AugMoleculeNumRings + - 
chebai_graph.preprocessing.properties.AugRDKit2DNormalized From 88ed36b6fa02483c2e7d54109d8fe4cde9b52bd0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Jun 2025 21:33:08 +0200 Subject: [PATCH 120/224] set tensor of zeros if None value is given to encoder --- .../preprocessing/property_encoder.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 0f822c9..d5f833a 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -48,7 +48,9 @@ def __init__(self, property, indices_dir=None, **kwargs): token.strip(): idx for idx, token in enumerate(pk) } self.index_length_start = len(self.cache) - self.offset = 0 + self._unk_token_idx = 0 + self._count_for_unk_token = 0 + self.offset = 1 @property def name(self): @@ -92,9 +94,18 @@ def on_finish(self): f"Now, the total length of the index of property {self.property.name} is {total_tokens}" ) + if self._count_for_unk_token > 0: + print( + f"{self.__class__.__name__} Encountered {self._count_for_unk_token} unknown tokens" + ) + def encode(self, token): """Returns a unique number for each token, automatically adds new tokens to the cache.""" - if not str(token) in self.cache: + if token is None: + self._count_for_unk_token += 1 + return torch.tensor([self._unk_token_idx]) + + if str(token) not in self.cache: self.cache[(str(token))] = len(self.cache) return torch.tensor([self.cache[str(token)] + self.offset]) @@ -105,6 +116,9 @@ class OneHotEncoder(IndexEncoder): def __init__(self, property, n_labels: Optional[int] = None, **kwargs): super().__init__(property, **kwargs) self._encoding_length = n_labels + # To undo any offset set by index encoder as its not relevant for one-hot-encoder (no offset needed for some unknown/reserved token) + # Also, `torch.nn.functional.one_hot` that class values must be smaller than num_classes. 
+ self.offset = 0 def get_encoding_length(self) -> int: return self._encoding_length or len(self.cache) @@ -131,6 +145,10 @@ def on_start(self, property_values): self.tokens_dict[token] = super().encode(token) def encode(self, token): + if token not in self.tokens_dict: + self._count_for_unk_token += 1 + return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64) + return torch.nn.functional.one_hot( self.tokens_dict[token], num_classes=self.get_encoding_length() ) @@ -144,6 +162,8 @@ def name(self): return "asis" def encode(self, token): + if token is None: + return torch.tensor([0]) return torch.tensor([token]) From 5202ef45d7f519f74875a9ca9d477729549cfb7a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 21 Jun 2025 21:34:04 +0200 Subject: [PATCH 121/224] return None for augmented nodes --- .../preprocessing/properties/__init__.py | 4 - .../properties/augmented_properties.py | 83 +++++++++++++------ configs/data/chebi50_augmented_baseline.yml | 1 - 3 files changed, 58 insertions(+), 30 deletions(-) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 8e5228b..76f7a77 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -28,14 +28,12 @@ AugAtomType, AugNumAtomBonds, AugAtomCharge, - AugAtomChirality, AugAtomHybridization, AugAtomNumHs, AugAtomAromaticity, AugBondAromaticity, AugBondType, AugBondInRing, - AugMoleculeNumRings, AugRDKit2DNormalized, ) @@ -65,13 +63,11 @@ "AugAtomType", "AugNumAtomBonds", "AugAtomCharge", - "AugAtomChirality", "AugAtomHybridization", "AugAtomNumHs", "AugAtomAromaticity", "AugBondAromaticity", "AugBondType", "AugBondInRing", - "AugMoleculeNumRings", "AugRDKit2DNormalized", ] diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index c6c51c0..33c52bd 100644 --- 
a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -52,10 +52,6 @@ def get_property_value(self, augmented_mol: Dict): return prop_list - @abstractmethod - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): - pass - def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): value = self._get_atom_prop_value(atom, prop) if not value: @@ -113,32 +109,63 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): # Delegate to superclass method for atom return super().get_atom_value(atom) elif isinstance(atom, dict): - return 0 + return None else: raise TypeError( f"Expected Chem.rdchem.Atom or dict, got {type(atom).__name__}" ) -class AugAtomType(AugNodeValueDefaulter, pr.AtomType): ... +class AugAtomType(AugNodeValueDefaulter, pr.AtomType): + # This property uses OneHotEncoder as default encoder + # TODO: Can we return 0 for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes + # Currently, we return None which leads to zero-tensor for augmented nodes + # RDKit uses 0 as the atomic number for a "dummy atom", which usually means: + # A placeholder atom (e.g. [*], R#, or attachment points in SMARTS/SMILES). + # An undefined or wildcard atom. + # A pseudoatom (e.g., for certain fragments or placeholders in reaction centers). + ... -class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): ... +class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): + # This property uses OneHotEncoder as default encoder + # Default return value for this property can't be zero, 0 is used for isolated atoms in molecule. It has to be None or actual node degree. + # TODO: Can return actual node degree/num of connections for augmented Nodes for this property? 
which will lead to use of one hot tensor for augmented nodes + # Currently, we return None which leads to zero-tensor for augmented nodes + # But then the question aries shall we count only the atoms connected to a fg node, or all nodes including atoms. Consider graph node too. + ... -class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): ... +class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): + # This property uses OneHotEncoder as default encoder + # Default return value for this property can't be zero, as atoms can have 0 charge. + # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes + # Currently, we return None which leads to zero-tensor for augmented nodes + ... -class AugAtomChirality(AugNodeValueDefaulter, pr.AtomChirality): ... +class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): + # This property uses OneHotEncoder as default encoder + # TODO: Can return some `HybridizationType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes + # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.HybridizationType + # Currently, we return None which leads to zero-tensor for augmented nodes + ... -class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): ... +class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): + # This property uses OneHotEncoder as default encoder + # Default return value for this property can't be zero, as atoms can have 0 Hydrogen atoms attached which mean atoms is full balanced by bonding with other non-hydrogen atoms. + # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes + # Currently, we return None which leads to zero-tensor for augmented nodes + ... -class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): ... 
- -class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): ... +class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): + # This property uses BoolEncoder as default encoder + # Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + # This is None is right value for augmented nodes its not part of any kind of aromatic ring. + ... # --------------------- Bond Properties ------------------------------ @@ -193,10 +220,6 @@ def get_property_value(self, augmented_mol: Dict) -> List: return prop_list - @abstractmethod - def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): - pass - def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str): value = self._get_bond_prop_value(bond, prop) if not value: @@ -228,29 +251,39 @@ def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): # Delegate to superclass method for bond return super().get_bond_value(bond) elif isinstance(bond, dict): - return 0 + return None else: raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") -class AugBondAromaticity(AugBondValueDefaulter, pr.BondAromaticity): ... +class AugBondAromaticity(AugBondValueDefaulter, pr.BondAromaticity): + # This property uses BoolEncoder as default encoder + # Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + # This is None is right value for augmented nodes its not part of any kind of aromatic ring. + ... -class AugBondType(AugBondValueDefaulter, pr.BondType): ... +class AugBondType(AugBondValueDefaulter, pr.BondType): + # This property uses OneHotEncoder as default encoder + # TODO: Can return some `BondType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? 
which will lead to use of one hot tensor for augmented nodes + # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondType + # Currently, we return None which leads to zero-tensor for augmented nodes + ... -class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): ... +class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): + # This property uses BoolEncoder as default encoder + # Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + # This is None is right value for augmented nodes its not part of any kind of aromatic ring. + ... # --------------------- Molecular Properties ------------------------------ class AugmentedMolecularProperty(pr.MolecularProperty, ABC): def get_property_value(self, augmented_mol: Dict) -> list: - mol: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] + mol: Chem.Mol = augmented_mol[AugmentedAtomProperty.MAIN_KEY]["atom_nodes"] assert isinstance(mol, Chem.Mol), "Molecule should be instance of `Chem.Mol`" return super().get_property_value(mol) -class AugMoleculeNumRings(AugmentedMolecularProperty, pr.MoleculeNumRings): ... - - class AugRDKit2DNormalized(AugmentedMolecularProperty, pr.RDKit2DNormalized): ... 
diff --git a/configs/data/chebi50_augmented_baseline.yml b/configs/data/chebi50_augmented_baseline.yml index 77c7346..7f9985c 100644 --- a/configs/data/chebi50_augmented_baseline.yml +++ b/configs/data/chebi50_augmented_baseline.yml @@ -10,5 +10,4 @@ init_args: - chebai_graph.preprocessing.properties.AugBondType - chebai_graph.preprocessing.properties.AugBondInRing - chebai_graph.preprocessing.properties.AugBondAromaticity - #- chebai_graph.preprocessing.properties.AugMoleculeNumRings - chebai_graph.preprocessing.properties.AugRDKit2DNormalized From 33d60a936513395951611ec6e937c051e85a454c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 23 Jun 2025 16:23:52 +0200 Subject: [PATCH 122/224] move base prop class into another base file --- .../preprocessing/properties/__init__.py | 6 +-- chebai_graph/preprocessing/properties/base.py | 51 +++++++++++++++++++ .../preprocessing/properties/properties.py | 50 +----------------- 3 files changed, 56 insertions(+), 51 deletions(-) create mode 100644 chebai_graph/preprocessing/properties/base.py diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 76f7a77..03bc8b3 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -2,10 +2,10 @@ # but it has to be imported after properties module, to avoid circular imports # This is because augmented properties module imports from properties module # isort: off + +from .base import MolecularProperty, AtomProperty, BondProperty + from .properties import ( - MolecularProperty, - AtomProperty, - BondProperty, AtomType, NumAtomBonds, AtomCharge, diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py new file mode 100644 index 0000000..b0e68a3 --- /dev/null +++ b/chebai_graph/preprocessing/properties/base.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod + +import rdkit.Chem as Chem + +from 
chebai_graph.preprocessing.property_encoder import IndexEncoder, PropertyEncoder + + +class MolecularProperty(ABC): + def __init__(self, encoder: PropertyEncoder | None = None): + if encoder is None: + encoder = IndexEncoder(self) + self.encoder = encoder + + @property + def name(self): + """Unique identifier for this property.""" + return self.__class__.__name__ + + def on_finish(self): + """Called after dataset processing is done.""" + self.encoder.on_finish() + + def __str__(self): + return self.name + + @abstractmethod + def get_property_value(self, mol: Chem.rdchem.Mol | dict): ... + + +class AtomProperty(MolecularProperty, ABC): + """Property of an atom.""" + + def get_property_value(self, mol: Chem.rdchem.Mol): + return [self.get_atom_value(atom) for atom in mol.GetAtoms()] + + @abstractmethod + def get_atom_value(self, atom: Chem.rdchem.Atom): + pass + + +class BondProperty(MolecularProperty, ABC): + def get_property_value(self, mol: Chem.rdchem.Mol): + return [self.get_bond_value(bond) for bond in mol.GetBonds()] + + @abstractmethod + def get_bond_value(self, bond: Chem.rdchem.Bond): + pass + + +class MoleculeProperty(MolecularProperty): + """Global property of a molecule.""" diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py index 0584eb2..8e0a425 100644 --- a/chebai_graph/preprocessing/properties/properties.py +++ b/chebai_graph/preprocessing/properties/properties.py @@ -1,5 +1,4 @@ -from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Optional import numpy as np import rdkit.Chem as Chem @@ -8,56 +7,11 @@ from chebai_graph.preprocessing.property_encoder import ( AsIsEncoder, BoolEncoder, - IndexEncoder, OneHotEncoder, PropertyEncoder, ) - -class MolecularProperty(ABC): - def __init__(self, encoder: Optional[PropertyEncoder] = None): - if encoder is None: - encoder = IndexEncoder(self) - self.encoder = encoder - - @property - def name(self): - 
"""Unique identifier for this property.""" - return self.__class__.__name__ - - def on_finish(self): - """Called after dataset processing is done.""" - self.encoder.on_finish() - - def __str__(self): - return self.name - - @abstractmethod - def get_property_value(self, mol: Chem.rdchem.Mol | Dict): ... - - -class AtomProperty(MolecularProperty, ABC): - """Property of an atom.""" - - def get_property_value(self, mol: Chem.rdchem.Mol): - return [self.get_atom_value(atom) for atom in mol.GetAtoms()] - - @abstractmethod - def get_atom_value(self, atom: Chem.rdchem.Atom): - pass - - -class BondProperty(MolecularProperty, ABC): - def get_property_value(self, mol: Chem.rdchem.Mol): - return [self.get_bond_value(bond) for bond in mol.GetBonds()] - - @abstractmethod - def get_bond_value(self, bond: Chem.rdchem.Bond): - pass - - -class MoleculeProperty(MolecularProperty): - """Global property of a molecule.""" +from .base import AtomProperty, BondProperty, MolecularProperty class AtomType(AtomProperty): From ed6787dfc0e2a9f26147a2b669a4ec1f9ab2b8d0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 23 Jun 2025 17:13:11 +0200 Subject: [PATCH 123/224] add wrapper to use mol prop for augmented graph --- .../properties/augmented_properties.py | 8 ++- chebai_graph/preprocessing/properties/base.py | 57 +++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 33c52bd..ed6538d 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import ABC from typing import Dict, List, Optional from rdkit import Chem @@ -6,6 +6,7 @@ from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder from . 
import properties as pr +from .base import FrozenPropertyAlias from .constants import * @@ -103,7 +104,8 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str return 0 -class AugNodeValueDefaulter(AugmentedAtomProperty, ABC): +class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): if isinstance(atom, Chem.rdchem.Atom): # Delegate to superclass method for atom @@ -245,7 +247,7 @@ def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): return self._check_modify_bond_prop_value(bond, EDGE_LEVEL) -class AugBondValueDefaulter(AugmentedBondProperty, ABC): +class AugBondValueDefaulter(AugmentedBondProperty, FrozenPropertyAlias, ABC): def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): if isinstance(bond, Chem.rdchem.Bond): # Delegate to superclass method for bond diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py index b0e68a3..3614772 100644 --- a/chebai_graph/preprocessing/properties/base.py +++ b/chebai_graph/preprocessing/properties/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from types import MappingProxyType import rdkit.Chem as Chem @@ -49,3 +50,59 @@ def get_bond_value(self, bond: Chem.rdchem.Bond): class MoleculeProperty(MolecularProperty): """Global property of a molecule.""" + + +class FrozenPropertyAlias(MolecularProperty, ABC): + """ + Wrapper base class for augmented graph properties that want to reuse existing molecular properties. + + This class allows augmented graph property classes to inherit both from this wrapper and a standard + molecular property (from `.properties`), enabling reuse of their encoders and index files without + modifying them. + + Key Features: + - Prevents new tokens from being added to the encoder cache by freezing it. 
+ - Automatically aligns the property name (used for encoder/index resolution) with the inherited + base property by removing the "Aug" prefix from the class name. + + Usage: + The derived class should: + - Inherit from `FrozenPropertyAlias` **and** a valid base molecular property class. + - Have a name starting with "Aug" (e.g., `AugAtomType`), which will be resolved to `AtomType`. + + Example: + ```python + class AugAtomType(FrozenPropertyAlias, AtomType): + ... + ``` + Note: + Subclass name of this class should with prefix "Aug" for above effect to take place. + + This allows `AugAtomType` to reuse the encoder, index files, and logic of `AtomType` while + integrating into augmented graph pipelines. + """ + + def __init__(self, encoder: PropertyEncoder | None = None): + super().__init__(encoder) + # Lock the encoder's cache to prevent adding new tokens + if hasattr(self.encoder, "cache") and isinstance(self.encoder.cache, dict): + self.encoder.cache = MappingProxyType(self.encoder.cache) + + @property + def name(self): + """ + Unique identifier for this property, with 'Aug' prefix removed if present. + This allows the encoder to reuse index files of the corresponding base property. 
+ """ + class_name = self.__class__.__name__ + return class_name[3:] if class_name.startswith("Aug") else class_name + + def on_finish(self): + if ( + hasattr(self.encoder, "cache") + and len(self.encoder.cache) > self.encoder.index_length_start + ): + raise ValueError( + f"{self.__class__.__name__} attempted to add new tokens to a {self.encoder.index_path}" + ) + super().on_finish() From ecfab05c67bc6b6ad4799a89cd939bae4483318a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 23 Jun 2025 18:29:49 +0200 Subject: [PATCH 124/224] avoid * imports for linters --- .../properties/augmented_properties.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index ed6538d..5f34554 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -5,13 +5,13 @@ from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder +from . import constants as k from . 
import properties as pr -from .base import FrozenPropertyAlias -from .constants import * +from .base import AtomProperty, BondProperty, FrozenPropertyAlias # --------------------- Atom Properties ----------------------------- -class AugmentedAtomProperty(pr.AtomProperty, ABC): +class AugmentedAtomProperty(AtomProperty, ABC): MAIN_KEY = "nodes" def get_property_value(self, augmented_mol: Dict): @@ -76,7 +76,7 @@ def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): - return self._check_modify_atom_prop_value(atom, NODE_LEVEL) + return self._check_modify_atom_prop_value(atom, k.NODE_LEVEL) class AugAtomFunctionalGroup(AugmentedAtomProperty): @@ -171,7 +171,7 @@ class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): # --------------------- Bond Properties ------------------------------ -class AugmentedBondProperty(pr.BondProperty, ABC): +class AugmentedBondProperty(BondProperty, ABC): MAIN_KEY = "edges" def get_property_value(self, augmented_mol: Dict) -> List: @@ -180,21 +180,21 @@ def get_property_value(self, augmented_mol: Dict) -> List: f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" ) - missing_keys = EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys() + missing_keys = k.EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys() if missing_keys: raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") - atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][WITHIN_ATOMS_EDGE] + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][k.WITHIN_ATOMS_EDGE] if not isinstance(atom_molecule, Chem.Mol): raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"]["{WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' + f'augmented_mol["{self.MAIN_KEY}"]["{k.WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' ) prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] - fg_atom_edges = 
augmented_mol[self.MAIN_KEY][ATOM_FG_EDGE] - fg_edges = augmented_mol[self.MAIN_KEY][WITHIN_FG_EDGE] - fg_graph_node_edges = augmented_mol[self.MAIN_KEY][FG_GRAPHNODE_EDGE] + fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE] + fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE] + fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.FG_GRAPHNODE_EDGE] if ( not isinstance(fg_atom_edges, dict) @@ -202,7 +202,7 @@ def get_property_value(self, augmented_mol: Dict) -> List: or not isinstance(fg_graph_node_edges, dict) ): raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["{ATOM_FG_EDGE}"]/["{WITHIN_FG_EDGE}"]/["{FG_GRAPHNODE_EDGE}"]) ' + f'augmented_mol["{self.MAIN_KEY}"](["{k.ATOM_FG_EDGE}"]/["{k.WITHIN_FG_EDGE}"]/["{k.FG_GRAPHNODE_EDGE}"]) ' f"must be an instance of dict containing its properties" ) @@ -215,7 +215,7 @@ def get_property_value(self, augmented_mol: Dict) -> List: [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] ) - num_directed_edges = augmented_mol[self.MAIN_KEY][NUM_EDGES] // 2 + num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2 assert ( len(prop_list) == num_directed_edges ), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. 
must be equal to {num_directed_edges} " @@ -244,7 +244,7 @@ def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): - return self._check_modify_bond_prop_value(bond, EDGE_LEVEL) + return self._check_modify_bond_prop_value(bond, k.EDGE_LEVEL) class AugBondValueDefaulter(AugmentedBondProperty, FrozenPropertyAlias, ABC): From c943d08a1dab2a77b854bc0be0ad4c0826293cad Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 23 Jun 2025 18:37:22 +0200 Subject: [PATCH 125/224] restore aug prop dervied from aug graph to original cls name --- .../preprocessing/properties/augmented_properties.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 5f34554..c78cdf7 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -71,7 +71,7 @@ def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): ) -class AugAtomNodeLevel(AugmentedAtomProperty): +class AtomNodeLevel(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) @@ -79,7 +79,7 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): return self._check_modify_atom_prop_value(atom, k.NODE_LEVEL) -class AugAtomFunctionalGroup(AugmentedAtomProperty): +class AtomFunctionalGroup(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) @@ -87,7 +87,7 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): return self._check_modify_atom_prop_value(atom, "FG") -class AugAtomRingSize(AugmentedAtomProperty): +class AtomRingSize(AugmentedAtomProperty): def __init__(self, encoder: Optional[PropertyEncoder] = 
None): super().__init__(encoder or OneHotEncoder(self)) @@ -105,7 +105,6 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): if isinstance(atom, Chem.rdchem.Atom): # Delegate to superclass method for atom @@ -239,7 +238,7 @@ def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") -class AugBondLevel(AugmentedBondProperty): +class BondLevel(AugmentedBondProperty): def __init__(self, encoder: Optional[PropertyEncoder] = None): super().__init__(encoder or OneHotEncoder(self)) From f148b0a04498e34a296fca1471ea2b78c855bc2e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 23 Jun 2025 18:47:46 +0200 Subject: [PATCH 126/224] fix start imports in reader + minor other fix --- .../preprocessing/properties/__init__.py | 16 +++++----- .../preprocessing/reader/augmented_reader.py | 30 +++++++++---------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 03bc8b3..281055f 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -21,10 +21,10 @@ ) from .augmented_properties import ( - AugAtomNodeLevel, - AugAtomFunctionalGroup, - AugAtomRingSize, - AugBondLevel, + AtomNodeLevel, + AtomFunctionalGroup, + AtomRingSize, + BondLevel, AugAtomType, AugNumAtomBonds, AugAtomCharge, @@ -56,10 +56,10 @@ "MoleculeNumRings", "RDKit2DNormalized", # -------- Augmented Molecular Properties -------- - "AugAtomNodeLevel", - "AugAtomFunctionalGroup", - "AugAtomRingSize", - "AugBondLevel", + "AtomNodeLevel", + "AtomFunctionalGroup", + "AtomRingSize", + "BondLevel", "AugAtomType", "AugNumAtomBonds", "AugAtomCharge", diff --git 
a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 616f1e8..a2f66b0 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -9,7 +9,7 @@ from chebai_graph.preprocessing.collate import GraphCollator from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import get_structure from chebai_graph.preprocessing.properties import MolecularProperty -from chebai_graph.preprocessing.properties.constants import * +from chebai_graph.preprocessing.properties import constants as k class _AugmentorReader(DataReader, ABC): @@ -180,7 +180,7 @@ def _read_data(self, smiles: str) -> GeomData | None: # Empty features initialized; node and edge features can be added later x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0)) - edge_attr = torch.zeros((augmented_molecule["edges"][NUM_EDGES], 0)) + edge_attr = torch.zeros((augmented_molecule["edges"][k.NUM_EDGES], 0)) assert ( edge_index.shape[0] == 2 @@ -286,11 +286,11 @@ def _augment_graph_structure( self._num_of_edges == total_edges ), f"Mismatch in number of edges: expected {total_edges}, got {self._num_of_edges}" edge_info = { - WITHIN_ATOMS_EDGE: mol, - ATOM_FG_EDGE: atom_fg_edges, - WITHIN_FG_EDGE: internal_fg_edges, - FG_GRAPHNODE_EDGE: fg_to_graph_edges, - NUM_EDGES: self._num_of_edges * 2, # Undirected edges + k.WITHIN_ATOMS_EDGE: mol, + k.ATOM_FG_EDGE: atom_fg_edges, + k.WITHIN_FG_EDGE: internal_fg_edges, + k.FG_GRAPHNODE_EDGE: fg_to_graph_edges, + k.NUM_EDGES: self._num_of_edges * 2, # Undirected edges } return undirected_edge_index, node_info, edge_info @@ -303,9 +303,9 @@ def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: mol (Chem.Mol): RDKit molecule. 
""" for atom in mol.GetAtoms(): - atom.SetProp(NODE_LEVEL, ATOM_NODE_LEVEL) + atom.SetProp(k.NODE_LEVEL, k.ATOM_NODE_LEVEL) for bond in mol.GetBonds(): - bond.SetProp(EDGE_LEVEL, WITHIN_ATOMS_EDGE) + bond.SetProp(k.EDGE_LEVEL, k.WITHIN_ATOMS_EDGE) @staticmethod def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: @@ -375,7 +375,7 @@ def _construct_fg_to_atom_structure( fg_atom_edge_index[0].append(self._num_of_nodes) fg_atom_edge_index[1].append(atom_idx) atom_fg_edges[f"{self._num_of_nodes}_{atom_idx}"] = { - EDGE_LEVEL: ATOM_FG_EDGE + k.EDGE_LEVEL: k.ATOM_FG_EDGE } self._num_of_edges += 1 @@ -395,7 +395,7 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes): # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings ring_size = len(connected_atoms) fg_nodes[self._num_of_nodes] = { - NODE_LEVEL: FG_NODE_LEVEL, + k.NODE_LEVEL: k.FG_NODE_LEVEL, "FG": f"RING_{ring_size}", "RING": ring_size, } @@ -445,7 +445,7 @@ def _set_fg_prop(self, connected_atoms, fg_nodes): raise AssertionError("Expected at least one atom with a functional group.") fg_nodes[self._num_of_nodes] = { - NODE_LEVEL: FG_NODE_LEVEL, + k.NODE_LEVEL: k.FG_NODE_LEVEL, "FG": representative_atom.GetProp("FG"), "RING": 0, } @@ -481,7 +481,7 @@ def add_fg_internal_edge(source_fg, target_fg): # Eg. 
In CHEBI:52723, atom idx 13 and 16 of a FG points to atom idx 18 of another FG internal_edge_index[0].append(source_fg) internal_edge_index[1].append(target_fg) - internal_fg_edges[edge_str] = {EDGE_LEVEL: WITHIN_FG_EDGE} + internal_fg_edges[edge_str] = {k.EDGE_LEVEL: k.WITHIN_FG_EDGE} self._num_of_edges += 1 for bond in bonds: @@ -528,7 +528,7 @@ def _construct_fg_to_graph_node_structure( - Graph-level node attribute - FG to Graph Edge attributes """ - graph_node = {NODE_LEVEL: GRAPH_NODE_LEVEL, "FG": "graph_fg", "RING": "0"} + graph_node = {k.NODE_LEVEL: k.GRAPH_NODE_LEVEL, "FG": "graph_fg", "RING": "0"} fg_graph_edges = {} graph_edge_index = [[], []] @@ -537,7 +537,7 @@ def _construct_fg_to_graph_node_structure( graph_edge_index[0].append(self._num_of_nodes) graph_edge_index[1].append(fg_id) fg_graph_edges[f"{self._num_of_nodes}_{fg_id}"] = { - EDGE_LEVEL: FG_GRAPHNODE_EDGE + k.EDGE_LEVEL: k.FG_GRAPHNODE_EDGE } self._num_of_edges += 1 self._num_of_nodes += 1 From d5798587c06eab75ba8784c7ce1868ff429d42e2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 23 Jun 2025 18:58:57 +0200 Subject: [PATCH 127/224] remove lambd exp as suggested by ruff linter https://docs.astral.sh/ruff/rules/lambda-assignment/ --- chebai_graph/preprocessing/datasets/chebi.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 27a3c49..e7fe967 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -100,23 +100,22 @@ def _setup_properties(self): idents = [row["ident"] for row in raw_data] features = [row["features"] for row in raw_data] - # use vectorized version of encode function, apply only if value is present - enc_if_not_none = lambda encode, value: ( - [encode(atom_v) for atom_v in value] - if value is not None and len(value) > 0 - else None - ) + def enc_if_not_none(encode, value): + if value is not 
None and len(value) > 0: + return [encode(atom_v) for atom_v in value] + else: + return None for property in self.properties: if not os.path.isfile(self.get_property_path(property)): rank_zero_info(f"Processing property {property.name}") # read all property values first, then encode - rank_zero_info(f"\tReading property valeus...") + rank_zero_info(f"\tReading property values of {property.name}...") property_values = [ self.reader.read_property(feat, property) for feat in tqdm.tqdm(features) ] - rank_zero_info(f"\tEncoding property values...") + rank_zero_info(f"\tEncoding property values of {property.name}...") property.encoder.on_start(property_values=property_values) encoded_values = [ enc_if_not_none(property.encoder.encode, value) From 243722b1bb6e34293d7fa18618717ebdc740e703 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 23 Jun 2025 21:35:24 +0200 Subject: [PATCH 128/224] remove chebi version --- configs/data/chebi50_graph_properties.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/data/chebi50_graph_properties.yml b/configs/data/chebi50_graph_properties.yml index 4c52209..0b770b2 100644 --- a/configs/data/chebi50_graph_properties.yml +++ b/configs/data/chebi50_graph_properties.yml @@ -1,6 +1,5 @@ class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphProperties init_args: - chebi_version: 231 properties: - chebai_graph.preprocessing.properties.AtomType - chebai_graph.preprocessing.properties.NumAtomBonds From 751f63e53ecc910e7af7c644a17d06c240edc1bf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 26 Jun 2025 15:39:02 +0200 Subject: [PATCH 129/224] seperate pooling for atom and augmented nodes One vector: average of atom embeddings One vector: average of augmented node embeddings https://github.com/ChEB-AI/python-chebai-graph/pull/2#issuecomment-3006390618 --- chebai_graph/models/__init__.py | 3 +- chebai_graph/models/graph.py | 62 +++++++++++++++++++ chebai_graph/preprocessing/datasets/chebi.py | 5 ++ 
.../preprocessing/reader/augmented_reader.py | 25 +++++++- configs/model/gnn_resgated_aug.yml | 13 ++++ 5 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 configs/model/gnn_resgated_aug.yml diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index 2a46f65..d091527 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -1,3 +1,4 @@ from ._gat import GATModelWrapper +from .graph import ResGatedAugmentedGraphPred -__all__ = ["GATModelWrapper"] +__all__ = ["GATModelWrapper", "ResGatedAugmentedGraphPred"] diff --git a/chebai_graph/models/graph.py b/chebai_graph/models/graph.py index 5da9a62..ff53e07 100644 --- a/chebai_graph/models/graph.py +++ b/chebai_graph/models/graph.py @@ -188,6 +188,68 @@ def forward(self, batch): return a +class ResGatedAugmentedGraphPred(GraphBaseNet): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedAugmentedGraphPred" + + def __init__( + self, + config: typing.Dict, + n_linear_layers=2, + **kwargs, + ): + super().__init__(**kwargs) + self.gnn = ResGatedGraphConvNetBase(config, **kwargs) + self.linear_layers = torch.nn.ModuleList( + [ + torch.nn.Linear( + self.gnn.hidden_length + + (i == 0) * self.gnn.n_molecule_properties + + (i == 0) * self.gnn.hidden_length, + self.gnn.hidden_length, + ) + for i in range(n_linear_layers - 1) + ] + ) + self.final_layer = nn.Linear(self.gnn.hidden_length, self.out_dim) + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_atom_node = graph_data.is_atom_node.bool() # Boolean mask: shape [num_nodes] + is_augmented_node = ~is_atom_node + + node_embeddings = self.gnn(batch) + + atom_embeddings = node_embeddings[is_atom_node] + atom_batch = graph_data.batch[is_atom_node] + + augmented_node_embeddings = node_embeddings[is_augmented_node] + augmented_node_batch = graph_data.batch[is_augmented_node] + + # Scatter add separately + graph_vec_atoms = 
scatter_add(atom_embeddings, atom_batch, dim=0) + graph_vec_augmented_nodes = scatter_add( + augmented_node_embeddings, augmented_node_batch, dim=0 + ) + + # Concatenate both + graph_vector = torch.cat( + [ + graph_vec_atoms, + graph_data.molecule_attr, + graph_vec_augmented_nodes, + ], + dim=1, + ) + + for lin in self.linear_layers: + a = self.gnn.activation(lin(graph_vector)) + a = self.final_layer(a) + return a + + class ResGatedGraphConvNetPretrain(GraphBaseNet): """For pretraining. BaseNet with an additional output layer for predicting atom properties""" diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index e7fe967..12cf568 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -178,11 +178,16 @@ def _merge_props_into_base(self, row): ) else: molecule_attr = torch.cat([molecule_attr, property_values], dim=1) + + is_atom_node = ( + geom_data.is_atom_node if hasattr(geom_data, "is_atom_node") else None + ) return GeomData( x=x, edge_index=geom_data.edge_index, edge_attr=edge_attr, molecule_attr=molecule_attr, + is_atom_node=is_atom_node, ) def load_processed_data_from_file(self, filename): diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index a2f66b0..2972503 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,3 +1,4 @@ +import sys from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple @@ -179,7 +180,12 @@ def _read_data(self, smiles: str) -> GeomData | None: self.mol_object_buffer[smiles] = augmented_molecule # Empty features initialized; node and edge features can be added later - x = torch.zeros((augmented_molecule["nodes"]["num_nodes"], 0)) + NUM_NODES = augmented_molecule["nodes"]["num_nodes"] + assert ( + NUM_NODES is not None and NUM_NODES > 1 + ), "Num of nodes 
in augmented graph should be more than 1" + + x = torch.zeros((NUM_NODES, 0)) edge_attr = torch.zeros((augmented_molecule["edges"][k.NUM_EDGES], 0)) assert ( @@ -194,7 +200,14 @@ def _read_data(self, smiles: str) -> GeomData | None: len(set(edge_index[0].tolist())) == x.shape[0] ), f"Number of unique source nodes in edge_index ({len(set(edge_index[0].tolist()))}) does not match number of nodes in x ({x.shape[0]})" - return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) + # Create a boolean mask: True for atom, False for augmented + is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool) + NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms() + is_atom_mask[:NUM_ATOM_NODES] = True + + return GeomData( + x=x, edge_index=edge_index, edge_attr=edge_attr, is_atom_node=is_atom_mask + ) def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: """ @@ -267,6 +280,14 @@ def _augment_graph_structure( assert ( self._num_of_nodes == total_atoms ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._num_of_nodes}" + assert sys.version_info >= ( + 3, + 7, + ), "This code requires Python 3.7 or higher." 
+ # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order + # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights + # https://mail.python.org/pipermail/python-dev/2017-December/151283.html + # Order preservation is necessary to create `is_atom_node` mask node_info = { "atom_nodes": mol, "fg_nodes": fg_nodes, diff --git a/configs/model/gnn_resgated_aug.yml b/configs/model/gnn_resgated_aug.yml new file mode 100644 index 0000000..d7869ca --- /dev/null +++ b/configs/model/gnn_resgated_aug.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedAugmentedGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_length: 256 + hidden_length: 512 + dropout_rate: 0.1 + n_conv_layers: 3 + n_linear_layers: 3 + n_atom_properties: 158 + n_bond_properties: 7 + n_molecule_properties: 200 From 9f40c836751be6867beb97fe8ac12d0548260238 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 23:52:43 +0200 Subject: [PATCH 130/224] hydrogen bond donor/acceptor prop for fg --- .../properties/augmented_properties.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index c78cdf7..f6cbaf2 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -3,7 +3,11 @@ from rdkit import Chem -from chebai_graph.preprocessing.property_encoder import OneHotEncoder, PropertyEncoder +from chebai_graph.preprocessing.property_encoder import ( + BoolEncoder, + OneHotEncoder, + PropertyEncoder, +) from . import constants as k from . 
import properties as pr @@ -104,6 +108,43 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str return 0 +class IsHydrogenBondDonorFG(AugmentedAtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or BoolEncoder(self)) + # fmt: off + # https://github.com/thaonguyen217/farm_molecular_representation/blob/main/src/(6)gen_FG_KG.py#L26-L31 + self._hydrogen_bond_donor: set[str] = { + 'hydroxyl', 'hydroperoxy', 'primary_amine', 'secondary_amine', + 'hydrazone', 'primary_ketimine', 'secondary_ketimine', 'primary_aldimine', + 'amide', 'sulfhydryl', 'sulfonic_acid', 'thiolester', 'hemiacetal', + 'hemiketal', 'carboxyl', 'aldoxime', 'ketoxime' + } + # fmt: on + + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + fg = self._check_modify_atom_prop_value(atom, "FG") + return fg in self._hydrogen_bond_donor + + +class IsHydrogenBondAcceptorFG(AugmentedAtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or BoolEncoder(self)) + # fmt: off + # https://github.com/thaonguyen217/farm_molecular_representation/blob/main/src/(6)gen_FG_KG.py#L33-L39 + self._hydrogen_bond_acceptor: set[str] = { + 'ether', 'peroxy', 'haloformyl', 'ketone', 'aldehyde', 'carboxylate', + 'carboxyl', 'ester', 'ketal', 'carbonate_ester', 'carboxylic_anhydride', + 'primary_amine', 'secondary_amine', 'tertiary_amine', '4_ammonium_ion', + 'hydrazone', 'primary_ketimine', 'secondary_ketimine', 'primary_aldimine', + 'amide', 'sulfhydryl', 'sulfonic_acid', 'thiolester', 'aldoxime', 'ketoxime' + } + # fmt: on + + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + fg = self._check_modify_atom_prop_value(atom, "FG") + return fg in self._hydrogen_bond_acceptor + + class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): if isinstance(atom, Chem.rdchem.Atom): From 
d6004746b58ec86ae59d45e35d5482dad86237ab Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 29 Jun 2025 00:47:59 +0200 Subject: [PATCH 131/224] is_alkyl fg prop --- .../preprocessing/properties/__init__.py | 6 +++++ .../properties/augmented_properties.py | 8 +++++++ .../preprocessing/reader/augmented_reader.py | 24 +++++++++++++------ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 281055f..3ca40a9 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -24,6 +24,9 @@ AtomNodeLevel, AtomFunctionalGroup, AtomRingSize, + IsHydrogenBondDonorFG, + IsHydrogenBondAcceptorFG, + IsFGAlkyl, BondLevel, AugAtomType, AugNumAtomBonds, @@ -59,6 +62,9 @@ "AtomNodeLevel", "AtomFunctionalGroup", "AtomRingSize", + "IsHydrogenBondDonorFG", + "IsHydrogenBondAcceptorFG", + "IsFGAlkyl", "BondLevel", "AugAtomType", "AugNumAtomBonds", diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index f6cbaf2..c174340 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -145,6 +145,14 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): return fg in self._hydrogen_bond_acceptor +class IsFGAlkyl(AugmentedAtomProperty): + def __init__(self, encoder: Optional[PropertyEncoder] = None): + super().__init__(encoder or BoolEncoder(self)) + + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + return int(self._check_modify_atom_prop_value(atom, "is_alkyl")) + + class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): if isinstance(atom, Chem.rdchem.Atom): diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py 
b/chebai_graph/preprocessing/reader/augmented_reader.py index 2972503..6c8ba7a 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,3 +1,4 @@ +import re import sys from abc import ABC, abstractmethod from typing import Dict, List, Optional, Tuple @@ -379,9 +380,10 @@ def _construct_fg_to_atom_structure( fg_to_atoms_map = {} molecule_atoms_set = set() - for _, fg_group in structure.items(): + for fg_smiles, fg_group in structure.items(): fg_to_atoms_map[self._num_of_nodes] = fg_group is_ring_fg = fg_group["is_ring_fg"] + is_alkyl = 0 connected_atoms = [] # Build edge index for fg to atom nodes connections @@ -406,7 +408,7 @@ def _construct_fg_to_atom_structure( if is_ring_fg: self._set_ring_fg_prop(connected_atoms, fg_nodes) else: - self._set_fg_prop(connected_atoms, fg_nodes) + self._set_fg_prop(connected_atoms, fg_nodes, fg_smiles) self._num_of_nodes += 1 @@ -419,6 +421,7 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes): k.NODE_LEVEL: k.FG_NODE_LEVEL, "FG": f"RING_{ring_size}", "RING": ring_size, + "is_alkyl": "0", } # In this case, all atoms of Ring/Fused Ring are assigned the ring size as functional group for atom in connected_atoms: @@ -429,8 +432,9 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes): # An atom belonging to multiple rings in fused Ring has size "5-6", indicating size of each ring max_ring_size = max(list(map(int, ring_prop.split("-")))) atom.SetProp("FG", f"RING_{max_ring_size}") + atom.SetProp("is_alkyl", "0") - def _set_fg_prop(self, connected_atoms, fg_nodes): + def _set_fg_prop(self, connected_atoms, fg_nodes, fg_smiles): fg_set = {atom.GetProp("FG") for atom in connected_atoms} if not fg_set: raise ValueError( @@ -458,10 +462,15 @@ def _set_fg_prop(self, connected_atoms, fg_nodes): "All Connected atoms must belong to one functional group or None" ) - # Select any one connected atom to get FG type and ring size - representative_atom = next( - 
(atom for atom in connected_atoms if atom.GetProp("FG")), None - ) + check = re.sub(r"[CH\-\(\)\[\]/\\@]", "", fg_smiles) + is_alkyl = "1" if len(check) == 0 else "0" + + representative_atom = None + for atom in connected_atoms: + if atom.GetProp("FG"): + representative_atom = atom + atom.SetProp("is_alkyl", is_alkyl) + if representative_atom is None: raise AssertionError("Expected at least one atom with a functional group.") @@ -469,6 +478,7 @@ def _set_fg_prop(self, connected_atoms, fg_nodes): k.NODE_LEVEL: k.FG_NODE_LEVEL, "FG": representative_atom.GetProp("FG"), "RING": 0, + "is_alkyl": is_alkyl, } def _construct_fg_level_structure( From 8353cdd7e8665671f77f9624bc10d3acd5330c78 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 30 Jun 2025 21:34:26 +0200 Subject: [PATCH 132/224] add unk tokens to prop files --- .../bin/AtomCharge/indices_one_hot.txt | 1 + .../bin/AtomNumHs/indices_one_hot.txt | 1 + .../bin/BondType/indices_one_hot.txt | 1 + .../bin/NumAtomBonds/indices_one_hot.txt | 1 + .../properties/augmented_properties.py | 18 +++++++++++++++--- 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt index b0f2214..e929c5d 100644 --- a/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt @@ -11,3 +11,4 @@ 6 7 -5 +unk diff --git a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt index f024a9d..b2f1b65 100644 --- a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt @@ -5,3 +5,4 @@ 1 5 6 +unk diff --git a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt index 97ae8be..f36bdf7 100644 --- 
a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt @@ -3,3 +3,4 @@ SINGLE AROMATIC TRIPLE DOUBLE +UNSPECIFIED diff --git a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt index 7036755..9e58e88 100644 --- a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt @@ -9,3 +9,4 @@ 7 10 12 +unk diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index c174340..0474a42 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -154,12 +154,14 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): + DEFAULT_PROP_VAL = None + def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): if isinstance(atom, Chem.rdchem.Atom): # Delegate to superclass method for atom return super().get_atom_value(atom) elif isinstance(atom, dict): - return None + return self.DEFAULT_PROP_VAL else: raise TypeError( f"Expected Chem.rdchem.Atom or dict, got {type(atom).__name__}" @@ -176,6 +178,7 @@ class AugAtomType(AugNodeValueDefaulter, pr.AtomType): # An undefined or wildcard atom. # A pseudoatom (e.g., for certain fragments or placeholders in reaction centers). ... + DEFAULT_PROP_VAL = 0 class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): @@ -185,6 +188,7 @@ class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): # Currently, we return None which leads to zero-tensor for augmented nodes # But then the question aries shall we count only the atoms connected to a fg node, or all nodes including atoms. Consider graph node too. ... 
+ DEFAULT_PROP_VAL = "unk" class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): @@ -193,6 +197,7 @@ class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes # Currently, we return None which leads to zero-tensor for augmented nodes ... + DEFAULT_PROP_VAL = "unk" class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): @@ -201,6 +206,7 @@ class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.HybridizationType # Currently, we return None which leads to zero-tensor for augmented nodes ... + DEFAULT_PROP_VAL = Chem.rdchem.HybridizationType.UNSPECIFIED class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): @@ -209,6 +215,7 @@ class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes # Currently, we return None which leads to zero-tensor for augmented nodes ... 
+ DEFAULT_PROP_VAL = "unk" class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): @@ -296,12 +303,14 @@ def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): class AugBondValueDefaulter(AugmentedBondProperty, FrozenPropertyAlias, ABC): + DEFAULT_PROP_VAL = None + def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): if isinstance(bond, Chem.rdchem.Bond): # Delegate to superclass method for bond return super().get_bond_value(bond) elif isinstance(bond, dict): - return None + return self.DEFAULT_PROP_VAL else: raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") @@ -319,6 +328,7 @@ class AugBondType(AugBondValueDefaulter, pr.BondType): # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondType # Currently, we return None which leads to zero-tensor for augmented nodes ... + DEFAULT_PROP_VAL = Chem.rdchem.BondType.UNSPECIFIED class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): @@ -336,4 +346,6 @@ def get_property_value(self, augmented_mol: Dict) -> list: return super().get_property_value(mol) -class AugRDKit2DNormalized(AugmentedMolecularProperty, pr.RDKit2DNormalized): ... +class AugRDKit2DNormalized( + AugmentedMolecularProperty, FrozenPropertyAlias, pr.RDKit2DNormalized +): ... 
From 0f1c8eb673989f45673301a9efbc9ba98d6cd3e2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 12:37:25 +0200 Subject: [PATCH 133/224] base classes for seperate pooling of each type of node --- chebai_graph/models/base.py | 173 +++++++++++++++++++++++++++++++++++ chebai_graph/models/graph.py | 15 +-- 2 files changed, 176 insertions(+), 12 deletions(-) create mode 100644 chebai_graph/models/base.py diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py new file mode 100644 index 0000000..7b5f963 --- /dev/null +++ b/chebai_graph/models/base.py @@ -0,0 +1,173 @@ +from abc import ABC, abstractmethod + +import torch +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.structures import XYData +from torch_geometric.data import Data as GraphData +from torch_scatter import scatter_add + + +class GraphBaseNet(ChebaiBaseNet, ABC): + def _get_prediction_and_labels(self, data, labels, output): + return torch.sigmoid(output), labels.int() + + def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: + return batch.y.float() if batch.y is not None else None + + +class GraphNetWrapper(GraphBaseNet, ABC): + def __init__(self, config: dict, n_linear_layers, n_molecule_properties, **kwargs): + super().__init__(**kwargs) + self.gnn = self._get_gnn(config) + gnn_out_dim = config["out_dim"] if "out_dim" in config else config["hidden_dim"] + + self.lin_input_dim = self._get_lin_seq_input_dim( + gnn_out_dim=gnn_out_dim, + n_molecule_properties=n_molecule_properties, + ) + + self.lin_sequential: torch.nn.Sequential = self._get_linear_module_list( + n_linear_layers=n_linear_layers, + in_dim=self.lin_input_dim, + out_dim=self.out_dim, + ) + + @abstractmethod + def _get_gnn(self, config): + pass + + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + return gnn_out_dim + n_molecule_properties + + def _get_linear_module_list(self, n_linear_layers, in_dim, hidden_dim, out_dim): + if n_linear_layers < 1: + raise 
ValueError("n_linear_layers must be at least 1") + + layers = [] + if n_linear_layers == 1: + layers.append(torch.nn.Linear(in_dim, out_dim)) + else: + layers.append(torch.nn.Linear(in_dim, hidden_dim)) + for _ in range(n_linear_layers - 2): + layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) + layers.append(torch.nn.Linear(hidden_dim, out_dim)) + + return torch.nn.Sequential(*layers) + + +class AugmentedNodePoolingNet(GraphNetWrapper, ABC): + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + # atom_embeddings + molecule attributes + augmented_node_embeddings + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_atom_node = graph_data.is_atom_node.bool() # Boolean mask: shape [num_nodes] + is_augmented_node = ~is_atom_node + + node_embeddings = self.gnn(batch) + + atom_embeddings = node_embeddings[is_atom_node] + atom_batch = graph_data.batch[is_atom_node] + + augmented_node_embeddings = node_embeddings[is_augmented_node] + augmented_node_batch = graph_data.batch[is_augmented_node] + + # Scatter add separately + graph_vec_atoms = scatter_add(atom_embeddings, atom_batch, dim=0) + graph_vec_augmented_nodes = scatter_add( + augmented_node_embeddings, augmented_node_batch, dim=0 + ) + + # Concatenate both + graph_vector = torch.cat( + [ + graph_vec_atoms, + graph_data.molecule_attr, + graph_vec_augmented_nodes, + ], + dim=1, + ) + + return self.lin_sequential(graph_vector) + + +class GraphNodePoolingNet(GraphNetWrapper, ABC): + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + # all_nodes_embeddings_except_graph_node + molecule attributes + graph_node_embedding + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_not_graph_node = 
~is_graph_node + + node_embeddings = self.gnn(batch) + + graph_node_embedding = node_embeddings[is_graph_node] + graph_node_batch = graph_data.batch[is_graph_node] + + remaining_node_embedding = node_embeddings[is_not_graph_node] + remaining_node_batch = graph_data.batch[is_not_graph_node] + + # Scatter add separately + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + remaining_nodes_vec = scatter_add( + remaining_node_embedding, remaining_node_batch, dim=0 + ) + + # Concatenate both + graph_vector = torch.cat( + [ + remaining_nodes_vec, + graph_data.molecule_attr, + graph_node_vec, + ], + dim=1, + ) + + return self.lin_sequential(graph_vector) + + +class GraphNodeAugmentedNodePoolingNet(GraphNetWrapper, ABC): + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + # atom_embeddings + molecule attributes + functional_group_node_embeddings + graph_node_embeddings + return gnn_out_dim + n_molecule_properties + gnn_out_dim + gnn_out_dim + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + + node_embeddings = self.gnn(batch) + + graph_node_embedding = node_embeddings[is_graph_node] + graph_node_batch = graph_data.batch[is_graph_node] + + atom_embeddings = node_embeddings[is_atom_node] + atom_batch = graph_data.batch[is_atom_node] + + fg_node_embeddings = node_embeddings[is_fg_node] + fg_node_batch = graph_data.batch[is_fg_node] + + # Scatter add separately + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + atom_vec = scatter_add(atom_embeddings, atom_batch, dim=0) + fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0) + + # Concatenate both + graph_vector = torch.cat( + [ + atom_vec, + graph_data.molecule_attr, + fg_node_vec, + graph_node_vec, + ], + dim=1, + ) + + return 
self.lin_sequential(graph_vector) diff --git a/chebai_graph/models/graph.py b/chebai_graph/models/graph.py index ff53e07..3be0fdd 100644 --- a/chebai_graph/models/graph.py +++ b/chebai_graph/models/graph.py @@ -3,8 +3,6 @@ import torch import torch.nn.functional as F -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.structures import XYData from torch import nn from torch_geometric import nn as tgnn from torch_geometric.data import Data as GraphData @@ -12,15 +10,9 @@ from chebai_graph.loss.pretraining import MaskPretrainingLoss -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - - -class GraphBaseNet(ChebaiBaseNet): - def _get_prediction_and_labels(self, data, labels, output): - return torch.sigmoid(output), labels.int() +from .base import GraphBaseNet - def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: - return batch.y.float() if batch.y is not None else None +logging.getLogger("pysmiles").setLevel(logging.CRITICAL) class JCIGraphNet(GraphBaseNet): @@ -246,8 +238,7 @@ def forward(self, batch): for lin in self.linear_layers: a = self.gnn.activation(lin(graph_vector)) - a = self.final_layer(a) - return a + return self.final_layer(a) class ResGatedGraphConvNetPretrain(GraphBaseNet): From 4d31642bcfcb786409294a19df77243afd4e976f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 12:38:58 +0200 Subject: [PATCH 134/224] Revert "add unk tokens to prop files" This reverts commit 8353cdd7e8665671f77f9624bc10d3acd5330c78. 
--- .../bin/AtomCharge/indices_one_hot.txt | 1 - .../bin/AtomNumHs/indices_one_hot.txt | 1 - .../bin/BondType/indices_one_hot.txt | 1 - .../bin/NumAtomBonds/indices_one_hot.txt | 1 - .../properties/augmented_properties.py | 18 +++--------------- 5 files changed, 3 insertions(+), 19 deletions(-) diff --git a/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt index e929c5d..b0f2214 100644 --- a/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt @@ -11,4 +11,3 @@ 6 7 -5 -unk diff --git a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt index b2f1b65..f024a9d 100644 --- a/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt @@ -5,4 +5,3 @@ 1 5 6 -unk diff --git a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt index f36bdf7..97ae8be 100644 --- a/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt @@ -3,4 +3,3 @@ SINGLE AROMATIC TRIPLE DOUBLE -UNSPECIFIED diff --git a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt index 9e58e88..7036755 100644 --- a/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt @@ -9,4 +9,3 @@ 7 10 12 -unk diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 0474a42..c174340 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -154,14 +154,12 @@ def 
get_atom_value(self, atom: Chem.rdchem.Atom | Dict): class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): - DEFAULT_PROP_VAL = None - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): if isinstance(atom, Chem.rdchem.Atom): # Delegate to superclass method for atom return super().get_atom_value(atom) elif isinstance(atom, dict): - return self.DEFAULT_PROP_VAL + return None else: raise TypeError( f"Expected Chem.rdchem.Atom or dict, got {type(atom).__name__}" @@ -178,7 +176,6 @@ class AugAtomType(AugNodeValueDefaulter, pr.AtomType): # An undefined or wildcard atom. # A pseudoatom (e.g., for certain fragments or placeholders in reaction centers). ... - DEFAULT_PROP_VAL = 0 class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): @@ -188,7 +185,6 @@ class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): # Currently, we return None which leads to zero-tensor for augmented nodes # But then the question aries shall we count only the atoms connected to a fg node, or all nodes including atoms. Consider graph node too. ... - DEFAULT_PROP_VAL = "unk" class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): @@ -197,7 +193,6 @@ class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes # Currently, we return None which leads to zero-tensor for augmented nodes ... - DEFAULT_PROP_VAL = "unk" class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): @@ -206,7 +201,6 @@ class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.HybridizationType # Currently, we return None which leads to zero-tensor for augmented nodes ... 
- DEFAULT_PROP_VAL = Chem.rdchem.HybridizationType.UNSPECIFIED class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): @@ -215,7 +209,6 @@ class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes # Currently, we return None which leads to zero-tensor for augmented nodes ... - DEFAULT_PROP_VAL = "unk" class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): @@ -303,14 +296,12 @@ def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): class AugBondValueDefaulter(AugmentedBondProperty, FrozenPropertyAlias, ABC): - DEFAULT_PROP_VAL = None - def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): if isinstance(bond, Chem.rdchem.Bond): # Delegate to superclass method for bond return super().get_bond_value(bond) elif isinstance(bond, dict): - return self.DEFAULT_PROP_VAL + return None else: raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") @@ -328,7 +319,6 @@ class AugBondType(AugBondValueDefaulter, pr.BondType): # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondType # Currently, we return None which leads to zero-tensor for augmented nodes ... - DEFAULT_PROP_VAL = Chem.rdchem.BondType.UNSPECIFIED class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): @@ -346,6 +336,4 @@ def get_property_value(self, augmented_mol: Dict) -> list: return super().get_property_value(mol) -class AugRDKit2DNormalized( - AugmentedMolecularProperty, FrozenPropertyAlias, pr.RDKit2DNormalized -): ... +class AugRDKit2DNormalized(AugmentedMolecularProperty, pr.RDKit2DNormalized): ... 
From 460fbbae7afd78ebf4797d328a739a010bdef822 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 12:58:32 +0200 Subject: [PATCH 135/224] activation for intermediate lin layers --- chebai_graph/models/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 7b5f963..53ab7aa 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -20,15 +20,17 @@ def __init__(self, config: dict, n_linear_layers, n_molecule_properties, **kwarg super().__init__(**kwargs) self.gnn = self._get_gnn(config) gnn_out_dim = config["out_dim"] if "out_dim" in config else config["hidden_dim"] - + self.activation = torch.nn.ELU self.lin_input_dim = self._get_lin_seq_input_dim( gnn_out_dim=gnn_out_dim, n_molecule_properties=n_molecule_properties, ) + lin_hidden_dim = kwargs.get("lin_hidden_dim", gnn_out_dim) self.lin_sequential: torch.nn.Sequential = self._get_linear_module_list( n_linear_layers=n_linear_layers, in_dim=self.lin_input_dim, + hidden_dim=lin_hidden_dim, out_dim=self.out_dim, ) @@ -48,8 +50,10 @@ def _get_linear_module_list(self, n_linear_layers, in_dim, hidden_dim, out_dim): layers.append(torch.nn.Linear(in_dim, out_dim)) else: layers.append(torch.nn.Linear(in_dim, hidden_dim)) + layers.append(self.activation()) for _ in range(n_linear_layers - 2): layers.append(torch.nn.Linear(hidden_dim, hidden_dim)) + layers.append(self.activation()) layers.append(torch.nn.Linear(hidden_dim, out_dim)) return torch.nn.Sequential(*layers) @@ -80,7 +84,7 @@ def forward(self, batch): augmented_node_embeddings, augmented_node_batch, dim=0 ) - # Concatenate both + # Concatenate all graph_vector = torch.cat( [ graph_vec_atoms, @@ -118,7 +122,7 @@ def forward(self, batch): remaining_node_embedding, remaining_node_batch, dim=0 ) - # Concatenate both + # Concatenate all graph_vector = torch.cat( [ remaining_nodes_vec, @@ -159,7 +163,7 @@ def forward(self, batch): atom_vec = 
scatter_add(atom_embeddings, atom_batch, dim=0) fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0) - # Concatenate both + # Concatenate all graph_vector = torch.cat( [ atom_vec, From 623b474e6e2e7ca0f064c346dc9b496ebbe70007 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 14:09:17 +0200 Subject: [PATCH 136/224] base class for fg node only pooling --- chebai_graph/models/base.py | 42 ++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 53ab7aa..ab773a8 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -135,7 +135,47 @@ def forward(self, batch): return self.lin_sequential(graph_vector) -class GraphNodeAugmentedNodePoolingNet(GraphNetWrapper, ABC): +class FGNodePoolingNet(GraphNetWrapper, ABC): + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + # all_nodes_embeddings_except_fg_nodes + molecule attributes + fg_node_embedding + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + is_remaining_node = ~is_fg_node + + node_embeddings = self.gnn(batch) + + remaining_node_embedding = node_embeddings[is_remaining_node] + remaining_node_batch = graph_data.batch[is_remaining_node] + + fg_node_embeddings = node_embeddings[is_fg_node] + fg_node_batch = graph_data.batch[is_fg_node] + + # Scatter add separately + remaining_node_vec = scatter_add( + remaining_node_embedding, remaining_node_batch, dim=0 + ) + fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0) + + # Concatenate all + graph_vector = torch.cat( + [ + remaining_node_vec, + graph_data.molecule_attr, + fg_node_vec, + ], + dim=1, + ) + + return 
self.lin_sequential(graph_vector) + + +class GraphNodeFGNodePoolingNet(GraphNetWrapper, ABC): def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): # atom_embeddings + molecule attributes + functional_group_node_embeddings + graph_node_embeddings return gnn_out_dim + n_molecule_properties + gnn_out_dim + gnn_out_dim From caeb9319ecf64a83c2910a39e1d25fa93c4019bd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 14:09:54 +0200 Subject: [PATCH 137/224] model wrapper for resgated --- chebai_graph/models/model_wrappers.py | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 chebai_graph/models/model_wrappers.py diff --git a/chebai_graph/models/model_wrappers.py b/chebai_graph/models/model_wrappers.py new file mode 100644 index 0000000..7bdbc00 --- /dev/null +++ b/chebai_graph/models/model_wrappers.py @@ -0,0 +1,36 @@ +from abc import ABC + +import torch +from torch_geometric import nn as tgnn + +from .base import GraphNetWrapper + + +class ResGatedModelWrapper(GraphNetWrapper, ABC): + def _get_gnn(self, config): + in_length = config["in_length"] + hidden_length = config["hidden_length"] + dropout_rate = config["dropout_rate"] + n_atom_properties = int(config["n_atom_properties"]) + n_bond_properties = int(config["n_bond_properties"]) + n_conv_layers = int(config["n_conv_layers"]) + + convs = torch.nn.ModuleList() + for i in range(n_conv_layers): + if i == 0: + convs.append( + tgnn.ResGatedGraphConv( + n_atom_properties, + in_length, + # dropout=dropout_rate, + edge_dim=n_bond_properties, + ) + ) + convs.append( + tgnn.ResGatedGraphConv(in_length, in_length, edge_dim=n_bond_properties) + ) + convs.append( + tgnn.ResGatedGraphConv(in_length, hidden_length, edge_dim=n_bond_properties) + ) + + return convs From 0a393904f41e60ee1c013d748aac27f04ee54f10 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 14:11:01 +0200 Subject: [PATCH 138/224] classes for augmented graph with poolings --- 
chebai_graph/models/__init__.py | 15 +++++++++++++-- chebai_graph/models/augmented.py | 33 ++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 chebai_graph/models/augmented.py diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index d091527..97de9ce 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -1,4 +1,15 @@ from ._gat import GATModelWrapper -from .graph import ResGatedAugmentedGraphPred +from .augmented import ( + ResGatedAugNodePoolGraphPred, + ResGatedFGNodePoolGraphPred, + ResGatedGraphNodeFGNodePoolGraphPred, + ResGatedGraphNodePoolGraphPred, +) -__all__ = ["GATModelWrapper", "ResGatedAugmentedGraphPred"] +__all__ = [ + "GATModelWrapper", + "ResGatedAugNodePoolGraphPred", + "ResGatedGraphNodeFGNodePoolGraphPred", + "ResGatedGraphNodePoolGraphPred", + "ResGatedFGNodePoolGraphPred", +] diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py new file mode 100644 index 0000000..1d2552e --- /dev/null +++ b/chebai_graph/models/augmented.py @@ -0,0 +1,33 @@ +from .base import ( + AugmentedNodePoolingNet, + FGNodePoolingNet, + GraphNodeFGNodePoolingNet, + GraphNodePoolingNet, +) +from .model_wrappers import ResGatedModelWrapper + + +class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedModelWrapper): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedAugNodePoolGraphPred" + + +class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, ResGatedModelWrapper): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedGraphNodePoolGraphPred" + + +class ResGatedFGNodePoolGraphPred(FGNodePoolingNet, ResGatedModelWrapper): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedFGNodePoolGraphPred" + + +class ResGatedGraphNodeFGNodePoolGraphPred( + GraphNodeFGNodePoolingNet, ResGatedModelWrapper +): + """GNN for graph-level prediction for augmented 
graphs""" + + NAME = "ResGatedGraphNodeFGNodePoolGraphPred" From 192cb1d650aeb2bfa43b05e0222529ed1c11d65c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 14:59:59 +0200 Subject: [PATCH 139/224] resgated model separate file --- chebai_graph/models/model_wrappers.py | 31 +------------ chebai_graph/models/resgated.py | 64 +++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 29 deletions(-) create mode 100644 chebai_graph/models/resgated.py diff --git a/chebai_graph/models/model_wrappers.py b/chebai_graph/models/model_wrappers.py index 7bdbc00..b274815 100644 --- a/chebai_graph/models/model_wrappers.py +++ b/chebai_graph/models/model_wrappers.py @@ -1,36 +1,9 @@ from abc import ABC -import torch -from torch_geometric import nn as tgnn - from .base import GraphNetWrapper +from .resgated import ResGatedGraphConvNetBase class ResGatedModelWrapper(GraphNetWrapper, ABC): def _get_gnn(self, config): - in_length = config["in_length"] - hidden_length = config["hidden_length"] - dropout_rate = config["dropout_rate"] - n_atom_properties = int(config["n_atom_properties"]) - n_bond_properties = int(config["n_bond_properties"]) - n_conv_layers = int(config["n_conv_layers"]) - - convs = torch.nn.ModuleList() - for i in range(n_conv_layers): - if i == 0: - convs.append( - tgnn.ResGatedGraphConv( - n_atom_properties, - in_length, - # dropout=dropout_rate, - edge_dim=n_bond_properties, - ) - ) - convs.append( - tgnn.ResGatedGraphConv(in_length, in_length, edge_dim=n_bond_properties) - ) - convs.append( - tgnn.ResGatedGraphConv(in_length, hidden_length, edge_dim=n_bond_properties) - ) - - return convs + return ResGatedGraphConvNetBase(config=config) diff --git a/chebai_graph/models/resgated.py b/chebai_graph/models/resgated.py new file mode 100644 index 0000000..affbfa2 --- /dev/null +++ b/chebai_graph/models/resgated.py @@ -0,0 +1,64 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch_geometric import nn as tgnn +from 
torch_geometric.data import Data as GraphData + +from .base import GraphBaseNet + + +class ResGatedGraphConvNetBase(GraphBaseNet): + """GNN that supports edge attributes""" + + NAME = "ResGatedGraphConvNetBase" + + def __init__(self, config: dict, **kwargs): + super().__init__(**kwargs) + + self.in_length = config["in_length"] + self.hidden_length = config["hidden_length"] + self.dropout_rate = config["dropout_rate"] + self.n_conv_layers = config["n_conv_layers"] + self.n_atom_properties = int(config["n_atom_properties"]) + self.n_bond_properties = int(config["n_bond_properties"]) + + self.activation = F.elu + self.dropout = nn.Dropout(self.dropout_rate) + + self.convs = torch.nn.ModuleList([]) + for i in range(self.n_conv_layers): + if i == 0: + self.convs.append( + tgnn.ResGatedGraphConv( + self.n_atom_properties, + self.in_length, + # dropout=self.dropout_rate, + edge_dim=self.n_bond_properties, + ) + ) + self.convs.append( + tgnn.ResGatedGraphConv( + self.in_length, self.in_length, edge_dim=self.n_bond_properties + ) + ) + self.final_conv = tgnn.ResGatedGraphConv( + self.in_length, self.hidden_length, edge_dim=self.n_bond_properties + ) + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + a = graph_data.x.float() + # a = self.embedding(a) + + for conv in self.convs: + assert isinstance(conv, tgnn.ResGatedGraphConv) + a = self.activation( + conv(a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr) + ) + a = self.activation( + self.final_conv( + a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr + ) + ) + return a From ad80e84ab733072f3139c7e5807613321df49644 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 15:13:04 +0200 Subject: [PATCH 140/224] add normal regated without pool --- chebai_graph/models/__init__.py | 2 ++ chebai_graph/models/augmented.py | 2 +- chebai_graph/models/base.py | 9 +++++++++ chebai_graph/models/model_wrappers.py | 9 --------- 
chebai_graph/models/resgated.py | 13 ++++++++++++- 5 files changed, 24 insertions(+), 11 deletions(-) delete mode 100644 chebai_graph/models/model_wrappers.py diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index 97de9ce..2b3696d 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -5,9 +5,11 @@ ResGatedGraphNodeFGNodePoolGraphPred, ResGatedGraphNodePoolGraphPred, ) +from .resgated import ResGatedGraphPred __all__ = [ "GATModelWrapper", + "ResGatedGraphPred", "ResGatedAugNodePoolGraphPred", "ResGatedGraphNodeFGNodePoolGraphPred", "ResGatedGraphNodePoolGraphPred", diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py index 1d2552e..5e91f2d 100644 --- a/chebai_graph/models/augmented.py +++ b/chebai_graph/models/augmented.py @@ -4,7 +4,7 @@ GraphNodeFGNodePoolingNet, GraphNodePoolingNet, ) -from .model_wrappers import ResGatedModelWrapper +from .resgated import ResGatedModelWrapper class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedModelWrapper): diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index ab773a8..fd1a38d 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -58,6 +58,15 @@ def _get_linear_module_list(self, n_linear_layers, in_dim, hidden_dim, out_dim): return torch.nn.Sequential(*layers) + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + a = self.gnn(batch) + a = scatter_add(a, graph_data.batch, dim=0) + a = torch.cat([a, graph_data.molecule_attr], dim=1) + + return self.lin_sequential(a) + class AugmentedNodePoolingNet(GraphNetWrapper, ABC): def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): diff --git a/chebai_graph/models/model_wrappers.py b/chebai_graph/models/model_wrappers.py deleted file mode 100644 index b274815..0000000 --- a/chebai_graph/models/model_wrappers.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABC - 
-from .base import GraphNetWrapper -from .resgated import ResGatedGraphConvNetBase - - -class ResGatedModelWrapper(GraphNetWrapper, ABC): - def _get_gnn(self, config): - return ResGatedGraphConvNetBase(config=config) diff --git a/chebai_graph/models/resgated.py b/chebai_graph/models/resgated.py index affbfa2..fb421d3 100644 --- a/chebai_graph/models/resgated.py +++ b/chebai_graph/models/resgated.py @@ -1,10 +1,12 @@ +from abc import ABC + import torch import torch.nn.functional as F from torch import nn from torch_geometric import nn as tgnn from torch_geometric.data import Data as GraphData -from .base import GraphBaseNet +from .base import GraphBaseNet, GraphNetWrapper class ResGatedGraphConvNetBase(GraphBaseNet): @@ -62,3 +64,12 @@ def forward(self, batch): ) ) return a + + +class ResGatedModelWrapper(GraphNetWrapper, ABC): + def _get_gnn(self, config): + return ResGatedGraphConvNetBase(config=config) + + +class ResGatedGraphPred(GraphNetWrapper, ResGatedModelWrapper): + NAME = "ResGatedGraphPred" From 35f07c84ecdaf1c1637988b80e1fc90fd058e02a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 15:27:12 +0200 Subject: [PATCH 141/224] add ruff to pre-commit --- .pre-commit-config.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 108b91d..e32d80c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,3 +23,9 @@ repos: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.1 + hooks: + - id: ruff + args: [] # No --fix, disables formatting From 3041d4066ca299e1b5fdc91c7ce6a354b806040a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 15:33:05 +0200 Subject: [PATCH 142/224] fix ruff errors --- .../preprocessing/property_encoder.py | 2 +- .../preprocessing/reader/augmented_reader.py | 1 - .../utils/visualize_augmented_molecule.py | 34 +++++++++---------- 3 files changed, 18 
insertions(+), 19 deletions(-) diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index d5f833a..5d6b386 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -125,7 +125,7 @@ def get_encoding_length(self) -> int: @property def name(self): - return f"one_hot" + return "one_hot" def on_start(self, property_values): """To get correct number of classes during encoding, cache unique tokens beforehand""" diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 6c8ba7a..14f5ab3 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -383,7 +383,6 @@ def _construct_fg_to_atom_structure( for fg_smiles, fg_group in structure.items(): fg_to_atoms_map[self._num_of_nodes] = fg_group is_ring_fg = fg_group["is_ring_fg"] - is_alkyl = 0 connected_atoms = [] # Build edge index for fg to atom nodes connections diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 307585b..d16f2dd 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -10,16 +10,16 @@ from rdkit.Chem.Draw import rdMolDraw2D from torch import Tensor -from chebai_graph.preprocessing.properties.constants import * +from chebai_graph.preprocessing.properties import constants as k from chebai_graph.preprocessing.reader import GraphFGAugmentorReader matplotlib.use("TkAgg") EDGE_COLOR_MAP = { - WITHIN_ATOMS_EDGE: "#1f77b4", - ATOM_FG_EDGE: "#9467bd", - WITHIN_FG_EDGE: "#ff7f0e", - FG_GRAPHNODE_EDGE: "#2ca02c", + k.WITHIN_ATOMS_EDGE: "#1f77b4", + k.ATOM_FG_EDGE: "#9467bd", + k.WITHIN_FG_EDGE: "#ff7f0e", + k.FG_GRAPHNODE_EDGE: "#2ca02c", } NODE_COLOR_MAP = { @@ -90,22 +90,22 
@@ def _create_graph( src_nodes, tgt_nodes = edge_index.tolist() with_atom_edges = { f"{bond.GetBeginAtomIdx()}_{bond.GetEndAtomIdx()}" - for bond in augmented_graph_edges[WITHIN_ATOMS_EDGE].GetBonds() + for bond in augmented_graph_edges[k.WITHIN_ATOMS_EDGE].GetBonds() } - atom_fg_edges = set(augmented_graph_edges[ATOM_FG_EDGE]) - within_fg_edges = set(augmented_graph_edges[WITHIN_FG_EDGE]) - fg_graph_edges = set(augmented_graph_edges[FG_GRAPHNODE_EDGE]) + atom_fg_edges = set(augmented_graph_edges[k.ATOM_FG_EDGE]) + within_fg_edges = set(augmented_graph_edges[k.WITHIN_FG_EDGE]) + fg_graph_edges = set(augmented_graph_edges[k.FG_GRAPHNODE_EDGE]) for src, tgt in zip(src_nodes, tgt_nodes): undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} if undirected_edge_set & with_atom_edges: - edge_type = WITHIN_ATOMS_EDGE + edge_type = k.WITHIN_ATOMS_EDGE elif undirected_edge_set & atom_fg_edges: - edge_type = ATOM_FG_EDGE + edge_type = k.ATOM_FG_EDGE elif undirected_edge_set & within_fg_edges: - edge_type = WITHIN_FG_EDGE + edge_type = k.WITHIN_FG_EDGE elif undirected_edge_set & fg_graph_edges: - edge_type = FG_GRAPHNODE_EDGE + edge_type = k.FG_GRAPHNODE_EDGE else: raise ValueError("Unexpected edge type") G.add_edge(src, tgt, edge_type=edge_type, edge_color=EDGE_COLOR_MAP[edge_type]) @@ -266,10 +266,10 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: # Collect edges by type edge_type_to_edges = { - WITHIN_ATOMS_EDGE: [], - ATOM_FG_EDGE: [], - WITHIN_FG_EDGE: [], - FG_GRAPHNODE_EDGE: [], + k.WITHIN_ATOMS_EDGE: [], + k.ATOM_FG_EDGE: [], + k.WITHIN_FG_EDGE: [], + k.FG_GRAPHNODE_EDGE: [], } for src, tgt, data in G.edges(data=True): edge_type_to_edges[data["edge_type"]].append((src, tgt)) From e95ea79ecb550ee97910bad9507f6abde6e2b63a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 15:36:10 +0200 Subject: [PATCH 143/224] add ruff to action workflow --- .github/workflows/black.yml | 10 ---------- .github/workflows/lint.yml | 26 ++++++++++++++++++++++++++ 2 files 
changed, 26 insertions(+), 10 deletions(-) delete mode 100644 .github/workflows/black.yml create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml deleted file mode 100644 index b04fb15..0000000 --- a/.github/workflows/black.yml +++ /dev/null @@ -1,10 +0,0 @@ -name: Lint - -on: [push, pull_request] - -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: psf/black@stable diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..1b63c41 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' # or any version your project uses + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install black ruff + + - name: Run Black + run: black --check . + + - name: Run Ruff (no formatting) + run: ruff check . 
--no-fix From c3e48b71ee23eba7e09e26d3e37a1ac876c74a3a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 19:53:18 +0200 Subject: [PATCH 144/224] add gat model --- chebai_graph/models/__init__.py | 4 +-- chebai_graph/models/_gat.py | 59 -------------------------------- chebai_graph/models/augmented.py | 10 +++--- chebai_graph/models/base.py | 17 +++++++-- chebai_graph/models/gat.py | 43 +++++++++++++++++++++++ chebai_graph/models/resgated.py | 26 +++++--------- configs/model/gat.yml | 11 +++--- 7 files changed, 79 insertions(+), 91 deletions(-) delete mode 100644 chebai_graph/models/_gat.py create mode 100644 chebai_graph/models/gat.py diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index 2b3696d..71b94e7 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -1,14 +1,14 @@ -from ._gat import GATModelWrapper from .augmented import ( ResGatedAugNodePoolGraphPred, ResGatedFGNodePoolGraphPred, ResGatedGraphNodeFGNodePoolGraphPred, ResGatedGraphNodePoolGraphPred, ) +from .gat import GATGraphPred from .resgated import ResGatedGraphPred __all__ = [ - "GATModelWrapper", + "GATGraphPred", "ResGatedGraphPred", "ResGatedAugNodePoolGraphPred", "ResGatedGraphNodeFGNodePoolGraphPred", diff --git a/chebai_graph/models/_gat.py b/chebai_graph/models/_gat.py deleted file mode 100644 index 5dbed73..0000000 --- a/chebai_graph/models/_gat.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -import torch.nn.functional as F -from torch_geometric.data import Data as GraphData -from torch_geometric.nn.models import GAT -from torch_scatter import scatter_add - -from .graph import GraphBaseNet - - -class GATModelWrapper(GraphBaseNet): - NAME = "GATModel" - - def __init__(self, config: dict, **kwargs): - super().__init__(**kwargs) - - self._hidden_length = int(config.pop("hidden_length")) - self._dropout_rate = float(config.pop("dropout_rate", 0.1)) - self._n_conv_layers = int(config.pop("n_conv_layers", 3)) - 
self._n_linear_layers = int(config.pop("n_linear_layers", 3)) - self._n_atom_properties = int(config.pop("n_atom_properties")) - self._n_bond_properties = int(config.pop("n_bond_properties")) - self._n_molecule_properties = int(config.pop("n_molecule_properties")) - self._gat = GAT( - in_channels=self._n_atom_properties, - hidden_channels=self._hidden_length, - num_layers=self._n_conv_layers, - dropout=self._dropout_rate, - edge_dim=self._n_bond_properties, - **config, - ) - - self._ffn_activation = F.elu - - self.linear_layers = torch.nn.ModuleList( - [ - torch.nn.Linear( - self._hidden_length - + (self._n_molecule_properties if i == 0 else 0), - self._hidden_length, - ) - for i in range(self._n_linear_layers - 1) - ] - ) - self.final_layer = torch.nn.Linear(self._hidden_length, self.out_dim) - - def forward(self, batch): - graph_data = batch["features"][0] - assert isinstance(graph_data, GraphData) - x = graph_data.x.float() - a = self._gat.forward( - x=x, edge_index=graph_data.edge_index.long(), edge_attr=graph_data.edge_attr - ) - a = scatter_add(a, graph_data.batch, dim=0) - - a = torch.cat([a, graph_data.molecule_attr], dim=1) - - for lin in self.linear_layers: - a = self._ffn_activation(lin(a)) - return self.final_layer(a) diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py index 5e91f2d..b9395e2 100644 --- a/chebai_graph/models/augmented.py +++ b/chebai_graph/models/augmented.py @@ -4,29 +4,29 @@ GraphNodeFGNodePoolingNet, GraphNodePoolingNet, ) -from .resgated import ResGatedModelWrapper +from .resgated import ResGatedGraphPred -class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedModelWrapper): +class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred): """GNN for graph-level prediction for augmented graphs""" NAME = "ResGatedAugNodePoolGraphPred" -class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, ResGatedModelWrapper): +class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, 
ResGatedGraphPred): """GNN for graph-level prediction for augmented graphs""" NAME = "ResGatedGraphNodePoolGraphPred" -class ResGatedFGNodePoolGraphPred(FGNodePoolingNet, ResGatedModelWrapper): +class ResGatedFGNodePoolGraphPred(FGNodePoolingNet, ResGatedGraphPred): """GNN for graph-level prediction for augmented graphs""" NAME = "ResGatedFGNodePoolGraphPred" class ResGatedGraphNodeFGNodePoolGraphPred( - GraphNodeFGNodePoolingNet, ResGatedModelWrapper + GraphNodeFGNodePoolingNet, ResGatedGraphPred ): """GNN for graph-level prediction for augmented graphs""" diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index fd1a38d..f16e812 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -15,11 +15,25 @@ def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: return batch.y.float() if batch.y is not None else None +class GraphModelBase(torch.nn.Module, ABC): + """Base class for graph-based models with a configuration dictionary.""" + + def __init__(self, config: dict, **kwargs): + super().__init__(**kwargs) + self.hidden_length = int(config["hidden_length"]) + self.dropout_rate = float(config["dropout_rate"]) + self.n_conv_layers = int(config["n_conv_layers"]) + self.n_atom_properties = int(config["n_atom_properties"]) + self.n_bond_properties = int(config["n_bond_properties"]) + + class GraphNetWrapper(GraphBaseNet, ABC): def __init__(self, config: dict, n_linear_layers, n_molecule_properties, **kwargs): super().__init__(**kwargs) self.gnn = self._get_gnn(config) - gnn_out_dim = config["out_dim"] if "out_dim" in config else config["hidden_dim"] + gnn_out_dim = ( + config["out_dim"] if "out_dim" in config else config["hidden_length"] + ) self.activation = torch.nn.ELU self.lin_input_dim = self._get_lin_seq_input_dim( gnn_out_dim=gnn_out_dim, @@ -64,7 +78,6 @@ def forward(self, batch): a = self.gnn(batch) a = scatter_add(a, graph_data.batch, dim=0) a = torch.cat([a, graph_data.molecule_attr], dim=1) - return 
self.lin_sequential(a) diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py new file mode 100644 index 0000000..cb4e2f4 --- /dev/null +++ b/chebai_graph/models/gat.py @@ -0,0 +1,43 @@ +import torch +from torch.nn import ELU +from torch_geometric.data import Data as GraphData +from torch_geometric.nn.models import GAT + +from .base import GraphModelBase, GraphNetWrapper + + +class GATGraphConvNetBase(GraphModelBase): + def __init__(self, config, **kwargs): + super().__init__(config=config, **kwargs) + self.heads = int(config["heads"]) + self.v2 = bool(config["v2"]) + self.activation = ELU() # instantiate once + self.gat = GAT( + in_channels=self.n_atom_properties, + hidden_channels=self.hidden_length, + num_layers=self.n_conv_layers, + dropout=self.dropout_rate, + edge_dim=self.n_bond_properties, + heads=self.heads, + v2=self.v2, + act=ELU, + ) + + def forward(self, batch: dict) -> torch.Tensor: + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + a = self.gat( + x=graph_data.x.float(), + edge_index=graph_data.edge_index, + edge_attr=graph_data.edge_attr, + ) + + return self.activation(a) + + +class GATGraphPred(GraphNetWrapper): + NAME = "GATGraphPred" + + def _get_gnn(self, config): + return GATGraphConvNetBase(config=config) diff --git a/chebai_graph/models/resgated.py b/chebai_graph/models/resgated.py index fb421d3..5667be6 100644 --- a/chebai_graph/models/resgated.py +++ b/chebai_graph/models/resgated.py @@ -1,28 +1,20 @@ -from abc import ABC - import torch import torch.nn.functional as F from torch import nn from torch_geometric import nn as tgnn from torch_geometric.data import Data as GraphData -from .base import GraphBaseNet, GraphNetWrapper +from .base import GraphModelBase, GraphNetWrapper -class ResGatedGraphConvNetBase(GraphBaseNet): +class ResGatedGraphConvNetBase(GraphModelBase): """GNN that supports edge attributes""" NAME = "ResGatedGraphConvNetBase" - def __init__(self, config: dict, **kwargs): - 
super().__init__(**kwargs) - - self.in_length = config["in_length"] - self.hidden_length = config["hidden_length"] - self.dropout_rate = config["dropout_rate"] - self.n_conv_layers = config["n_conv_layers"] - self.n_atom_properties = int(config["n_atom_properties"]) - self.n_bond_properties = int(config["n_bond_properties"]) + def __init__(self, config, **kwargs): + super().__init__(config=config, **kwargs) + self.in_length = int(config["in_length"]) self.activation = F.elu self.dropout = nn.Dropout(self.dropout_rate) @@ -66,10 +58,8 @@ def forward(self, batch): return a -class ResGatedModelWrapper(GraphNetWrapper, ABC): +class ResGatedGraphPred(GraphNetWrapper): + NAME = "ResGatedGraphPred" + def _get_gnn(self, config): return ResGatedGraphConvNetBase(config=config) - - -class ResGatedGraphPred(GraphNetWrapper, ResGatedModelWrapper): - NAME = "ResGatedGraphPred" diff --git a/configs/model/gat.yml b/configs/model/gat.yml index 0cc9b24..688edb2 100644 --- a/configs/model/gat.yml +++ b/configs/model/gat.yml @@ -1,14 +1,15 @@ -class_path: chebai_graph.models.GATModelWrapper +class_path: chebai_graph.models.GATGraphPred init_args: optimizer_kwargs: lr: 1e-3 config: hidden_length: 512 - dropout_rate: 0.1 + dropout_rate: 0 n_conv_layers: 3 heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) - # v2: True # -- to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv - n_linear_layers: 3 + v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv n_atom_properties: 158 n_bond_properties: 7 - n_molecule_properties: 200 + + n_molecule_properties: 200 + n_linear_layers: 3 From 7efbea5609715f52d8f42db92110046e6a426c98 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 2 Jul 2025 23:12:57 +0200 Subject: [PATCH 145/224] Create resgated.yml --- configs/model/resgated.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 
configs/model/resgated.yml diff --git a/configs/model/resgated.yml b/configs/model/resgated.yml new file mode 100644 index 0000000..83746ae --- /dev/null +++ b/configs/model/resgated.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_length: 256 + hidden_length: 512 + dropout_rate: 0.1 + n_conv_layers: 3 + n_atom_properties: 158 + n_bond_properties: 7 + n_molecule_properties: 200 + n_linear_layers: 2 From f52d6e3d125bbebfd23001efa55aa21641a2e41c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 00:23:42 +0200 Subject: [PATCH 146/224] more class for ablation studies --- chebai_graph/models/__init__.py | 4 ++ chebai_graph/models/augmented.py | 18 ++++++++ chebai_graph/models/base.py | 77 ++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+) diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index 71b94e7..0c578e4 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -1,7 +1,9 @@ from .augmented import ( ResGatedAugNodePoolGraphPred, + ResGatedFGNodeNoGraphNodeGraphPred, ResGatedFGNodePoolGraphPred, ResGatedGraphNodeFGNodePoolGraphPred, + ResGatedGraphNodeNoFGNodeGraphPred, ResGatedGraphNodePoolGraphPred, ) from .gat import GATGraphPred @@ -10,8 +12,10 @@ __all__ = [ "GATGraphPred", "ResGatedGraphPred", + "ResGatedFGNodeNoGraphNodeGraphPred", "ResGatedAugNodePoolGraphPred", "ResGatedGraphNodeFGNodePoolGraphPred", "ResGatedGraphNodePoolGraphPred", + "ResGatedGraphNodeNoFGNodeGraphPred", "ResGatedFGNodePoolGraphPred", ] diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py index b9395e2..a7acd91 100644 --- a/chebai_graph/models/augmented.py +++ b/chebai_graph/models/augmented.py @@ -1,7 +1,9 @@ from .base import ( AugmentedNodePoolingNet, FGNodePoolingNet, + FGNodePoolingNoGraphNodeNet, GraphNodeFGNodePoolingNet, + GraphNodeNoFGNodePoolingNet, GraphNodePoolingNet, ) from .resgated 
import ResGatedGraphPred @@ -31,3 +33,19 @@ class ResGatedGraphNodeFGNodePoolGraphPred( """GNN for graph-level prediction for augmented graphs""" NAME = "ResGatedGraphNodeFGNodePoolGraphPred" + + +class ResGatedGraphNodeNoFGNodeGraphPred( + GraphNodeNoFGNodePoolingNet, ResGatedGraphPred +): + """GNN for graph-level prediction for augmented graphs without FG nodes""" + + NAME = "ResGatedGraphNodeNoFGNodeGraphPred" + + +class ResGatedFGNodeNoGraphNodeGraphPred( + FGNodePoolingNoGraphNodeNet, ResGatedGraphPred +): + """GNN for graph-level prediction for augmented graphs without FG nodes""" + + NAME = "ResGatedFGNodeNoGraphNodeGraphPred" diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index f16e812..bef9b05 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -237,3 +237,80 @@ def forward(self, batch): ) return self.lin_sequential(graph_vector) + + +class FGNodePoolingNoGraphNodeNet(GraphNetWrapper, ABC): + """Graph Node not considered here in any computation""" + + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + # atom_embeddings + molecule attributes + functional_group_node_embeddings + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + + node_embeddings = self.gnn(batch) + + atom_embeddings = node_embeddings[is_atom_node] + atom_batch = graph_data.batch[is_atom_node] + + fg_node_embeddings = node_embeddings[is_fg_node] + fg_node_batch = graph_data.batch[is_fg_node] + + # Scatter add separately + atom_vec = scatter_add(atom_embeddings, atom_batch, dim=0) + fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0) + + # Concatenate all + graph_vector = torch.cat( + [ + atom_vec, + graph_data.molecule_attr, + fg_node_vec, + ], + 
dim=1, + ) + + return self.lin_sequential(graph_vector) + + +class GraphNodeNoFGNodePoolingNet(GraphNetWrapper, ABC): + """Functional Group Nodes not considered here in any computation""" + + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + # atom_embeddings + molecule attributes + graph_node_embeddings + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch): + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + + node_embeddings = self.gnn(batch) + + graph_node_embedding = node_embeddings[is_graph_node] + graph_node_batch = graph_data.batch[is_graph_node] + + atom_embeddings = node_embeddings[is_atom_node] + atom_batch = graph_data.batch[is_atom_node] + + # Scatter add separately + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + atom_vec = scatter_add(atom_embeddings, atom_batch, dim=0) + + # Concatenate all + graph_vector = torch.cat( + [ + atom_vec, + graph_data.molecule_attr, + graph_node_vec, + ], + dim=1, + ) + + return self.lin_sequential(graph_vector) From 55c4a972639a5140d2c32e4af7da489c2f48e158 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 13:15:10 +0200 Subject: [PATCH 147/224] add graph node mask --- chebai_graph/preprocessing/datasets/chebi.py | 4 ++++ chebai_graph/preprocessing/reader/augmented_reader.py | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 12cf568..ca07435 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -182,12 +182,16 @@ def _merge_props_into_base(self, row): is_atom_node = ( geom_data.is_atom_node if hasattr(geom_data, "is_atom_node") else None ) + is_graph_node = ( + geom_data.is_graph_node if hasattr(geom_data, 
"is_graph_node") else None + ) return GeomData( x=x, edge_index=geom_data.edge_index, edge_attr=edge_attr, molecule_attr=molecule_attr, is_atom_node=is_atom_node, + is_graph_node=is_graph_node, ) def load_processed_data_from_file(self, filename): diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 14f5ab3..e6ffa32 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -205,9 +205,15 @@ def _read_data(self, smiles: str) -> GeomData | None: is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool) NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms() is_atom_mask[:NUM_ATOM_NODES] = True + is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool) + is_graph_node[-1] = True return GeomData( - x=x, edge_index=edge_index, edge_attr=edge_attr, is_atom_node=is_atom_mask + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + is_atom_node=is_atom_mask, + is_graph_node=is_graph_node, ) def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: From d8059c91da9621f3e38d2ff662c7ca466de6da88 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 13:45:46 +0200 Subject: [PATCH 148/224] few pooling single only class for ablation study --- chebai_graph/models/base.py | 135 +++++++++++++++++++++++++----------- 1 file changed, 94 insertions(+), 41 deletions(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index bef9b05..016cbe3 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -94,24 +94,24 @@ def forward(self, batch): node_embeddings = self.gnn(batch) - atom_embeddings = node_embeddings[is_atom_node] - atom_batch = graph_data.batch[is_atom_node] + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] - augmented_node_embeddings = node_embeddings[is_augmented_node] - augmented_node_batch = 
graph_data.batch[is_augmented_node] + augmented_nodes_embeddings = node_embeddings[is_augmented_node] + augmented_nodes_batch = graph_data.batch[is_augmented_node] # Scatter add separately - graph_vec_atoms = scatter_add(atom_embeddings, atom_batch, dim=0) - graph_vec_augmented_nodes = scatter_add( - augmented_node_embeddings, augmented_node_batch, dim=0 + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + aug_nodes_vec = scatter_add( + augmented_nodes_embeddings, augmented_nodes_batch, dim=0 ) # Concatenate all graph_vector = torch.cat( [ - graph_vec_atoms, + atoms_vec, graph_data.molecule_attr, - graph_vec_augmented_nodes, + aug_nodes_vec, ], dim=1, ) @@ -135,13 +135,13 @@ def forward(self, batch): graph_node_embedding = node_embeddings[is_graph_node] graph_node_batch = graph_data.batch[is_graph_node] - remaining_node_embedding = node_embeddings[is_not_graph_node] - remaining_node_batch = graph_data.batch[is_not_graph_node] + remaining_nodes_embedding = node_embeddings[is_not_graph_node] + remaining_nodes_batch = graph_data.batch[is_not_graph_node] # Scatter add separately graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) remaining_nodes_vec = scatter_add( - remaining_node_embedding, remaining_node_batch, dim=0 + remaining_nodes_embedding, remaining_nodes_batch, dim=0 ) # Concatenate all @@ -172,24 +172,24 @@ def forward(self, batch): node_embeddings = self.gnn(batch) - remaining_node_embedding = node_embeddings[is_remaining_node] - remaining_node_batch = graph_data.batch[is_remaining_node] + remaining_nodes_embedding = node_embeddings[is_remaining_node] + remaining_nodes_batch = graph_data.batch[is_remaining_node] - fg_node_embeddings = node_embeddings[is_fg_node] - fg_node_batch = graph_data.batch[is_fg_node] + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] # Scatter add separately - remaining_node_vec = scatter_add( - remaining_node_embedding, remaining_node_batch, dim=0 
+ remaining_nodes_vec = scatter_add( + remaining_nodes_embedding, remaining_nodes_batch, dim=0 ) - fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) # Concatenate all graph_vector = torch.cat( [ - remaining_node_vec, + remaining_nodes_vec, graph_data.molecule_attr, - fg_node_vec, + fg_nodes_vec, ], dim=1, ) @@ -214,23 +214,23 @@ def forward(self, batch): graph_node_embedding = node_embeddings[is_graph_node] graph_node_batch = graph_data.batch[is_graph_node] - atom_embeddings = node_embeddings[is_atom_node] - atom_batch = graph_data.batch[is_atom_node] + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] - fg_node_embeddings = node_embeddings[is_fg_node] - fg_node_batch = graph_data.batch[is_fg_node] + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] # Scatter add separately graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) - atom_vec = scatter_add(atom_embeddings, atom_batch, dim=0) - fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0) + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) # Concatenate all graph_vector = torch.cat( [ - atom_vec, + atoms_vec, graph_data.molecule_attr, - fg_node_vec, + fg_nodes_vec, graph_node_vec, ], dim=1, @@ -255,22 +255,22 @@ def forward(self, batch): node_embeddings = self.gnn(batch) - atom_embeddings = node_embeddings[is_atom_node] - atom_batch = graph_data.batch[is_atom_node] + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] - fg_node_embeddings = node_embeddings[is_fg_node] - fg_node_batch = graph_data.batch[is_fg_node] + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] # Scatter add separately - atom_vec = 
scatter_add(atom_embeddings, atom_batch, dim=0) - fg_node_vec = scatter_add(fg_node_embeddings, fg_node_batch, dim=0) + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) # Concatenate all graph_vector = torch.cat( [ - atom_vec, + atoms_vec, graph_data.molecule_attr, - fg_node_vec, + fg_nodes_vec, ], dim=1, ) @@ -296,17 +296,17 @@ def forward(self, batch): graph_node_embedding = node_embeddings[is_graph_node] graph_node_batch = graph_data.batch[is_graph_node] - atom_embeddings = node_embeddings[is_atom_node] - atom_batch = graph_data.batch[is_atom_node] + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] # Scatter add separately graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) - atom_vec = scatter_add(atom_embeddings, atom_batch, dim=0) + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) # Concatenate all graph_vector = torch.cat( [ - atom_vec, + atoms_vec, graph_data.molecule_attr, graph_node_vec, ], @@ -314,3 +314,56 @@ def forward(self, batch): ) return self.lin_sequential(graph_vector) + + +class AugmentedOnlyPoolingNet(GraphNetWrapper, ABC): + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + return gnn_out_dim + n_molecule_properties + + def forward(self, batch): + graph_data = batch["features"][0] + is_atom_node = graph_data.is_atom_node.bool() + augmented_nodes_embeddings = self.gnn(batch)[~is_atom_node] + augmented_nodes_batch = graph_data.batch[~is_atom_node] + + aug_nodes_vec = scatter_add( + augmented_nodes_embeddings, augmented_nodes_batch, dim=0 + ) + graph_vector = torch.cat([aug_nodes_vec, graph_data.molecule_attr], dim=1) + + return self.lin_sequential(graph_vector) + + +class FGOnlyPoolingNet(GraphNetWrapper, ABC): + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + return gnn_out_dim + n_molecule_properties + + def forward(self, batch): 
+ graph_data = batch["features"][0] + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + fg_nodes_embeddings = self.gnn(batch)[~is_fg_node] + fg_nodes_batch = graph_data.batch[~is_fg_node] + + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) + graph_vector = torch.cat([fg_nodes_vec, graph_data.molecule_attr], dim=1) + + return self.lin_sequential(graph_vector) + + +class GraphNodeOnlyPoolingNet(GraphNetWrapper, ABC): + def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + return gnn_out_dim + n_molecule_properties + + def forward(self, batch): + graph_data = batch["features"][0] + is_graph_node = graph_data.is_graph_node.bool() + + graph_node_embedding = self.gnn(batch)[~is_graph_node] + graph_node_batch = graph_data.batch[~is_graph_node] + + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + graph_vector = torch.cat([graph_node_vec, graph_data.molecule_attr], dim=1) + + return self.lin_sequential(graph_vector) From 947f53691e24b11ad624a8489f1e8fee90ab8471 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Jul 2025 15:01:41 +0200 Subject: [PATCH 149/224] resgated new ablation study classes --- chebai_graph/models/__init__.py | 6 ++++++ chebai_graph/models/augmented.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index 0c578e4..cf1796b 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -1,9 +1,12 @@ from .augmented import ( ResGatedAugNodePoolGraphPred, + ResGatedAugOnlyPoolGraphPred, ResGatedFGNodeNoGraphNodeGraphPred, ResGatedFGNodePoolGraphPred, + ResGatedFGOnlyPoolGraphPred, ResGatedGraphNodeFGNodePoolGraphPred, ResGatedGraphNodeNoFGNodeGraphPred, + ResGatedGraphNodeOnlyPoolGraphPred, ResGatedGraphNodePoolGraphPred, ) from .gat import GATGraphPred @@ -18,4 +21,7 @@ 
"ResGatedGraphNodePoolGraphPred", "ResGatedGraphNodeNoFGNodeGraphPred", "ResGatedFGNodePoolGraphPred", + "ResGatedAugOnlyPoolGraphPred", + "ResGatedGraphNodeOnlyPoolGraphPred", + "ResGatedFGOnlyPoolGraphPred", ] diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py index a7acd91..35f72aa 100644 --- a/chebai_graph/models/augmented.py +++ b/chebai_graph/models/augmented.py @@ -1,9 +1,12 @@ from .base import ( AugmentedNodePoolingNet, + AugmentedOnlyPoolingNet, FGNodePoolingNet, FGNodePoolingNoGraphNodeNet, + FGOnlyPoolingNet, GraphNodeFGNodePoolingNet, GraphNodeNoFGNodePoolingNet, + GraphNodeOnlyPoolingNet, GraphNodePoolingNet, ) from .resgated import ResGatedGraphPred @@ -49,3 +52,21 @@ class ResGatedFGNodeNoGraphNodeGraphPred( """GNN for graph-level prediction for augmented graphs without FG nodes""" NAME = "ResGatedFGNodeNoGraphNodeGraphPred" + + +class ResGatedAugOnlyPoolGraphPred(AugmentedOnlyPoolingNet, ResGatedGraphPred): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedAugOnlyPoolGraphPred" + + +class ResGatedGraphNodeOnlyPoolGraphPred(GraphNodeOnlyPoolingNet, ResGatedGraphPred): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedGraphNodeOnlyPoolGraphPred" + + +class ResGatedFGOnlyPoolGraphPred(FGOnlyPoolingNet, ResGatedGraphPred): + """GNN for graph-level prediction for augmented graphs""" + + NAME = "ResGatedFGOnlyPoolGraphPred" From adc55bcf2a37b49d0d9162f766792b42de244c71 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 4 Jul 2025 16:40:28 +0200 Subject: [PATCH 150/224] fg node only fg nodes with edges only readers --- chebai_graph/preprocessing/reader/__init__.py | 4 +- .../preprocessing/reader/augmented_reader.py | 397 +++++++++++------- .../utils/visualize_augmented_molecule.py | 4 +- 3 files changed, 241 insertions(+), 164 deletions(-) diff --git a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py index 
09946f0..3569e01 100644 --- a/chebai_graph/preprocessing/reader/__init__.py +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -1,8 +1,8 @@ -from .augmented_reader import GraphFGAugmentorReader +from .augmented_reader import AtomFGWithFGEdgesAndGraphNodeReader from .reader import GraphPropertyReader, GraphReader __all__ = [ "GraphReader", "GraphPropertyReader", - "GraphFGAugmentorReader", + "AtomFGWithFGEdgesAndGraphNodeReader", ] diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index e6ffa32..b917220 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,7 +1,7 @@ import re import sys from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple +from typing import Optional import torch from chebai.preprocessing.reader import DataReader @@ -13,6 +13,15 @@ from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.properties import constants as k +assert sys.version_info >= ( + 3, + 7, +), "This code requires Python 3.7 or higher." +# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order +# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights +# https://mail.python.org/pipermail/python-dev/2017-December/151283.html +# Order preservation is necessary to to create `is_atom_node` mask + class _AugmentorReader(DataReader, ABC): """ @@ -37,11 +46,10 @@ def __init__(self, *args, **kwargs): # Record number of failure during augmented graph construction self.f_cnt_for_aug_graph = 0 self.mol_object_buffer = {} - self._num_of_nodes = 0 - self._num_of_edges = 0 + self._idx_of_node = 0 + self._idx_of_edge = 0 @classmethod - @abstractmethod def name(cls) -> str: """ Returns the name of the augmentor. @@ -49,9 +57,10 @@ def name(cls) -> str: Returns: str: Name of the augmentor. 
""" + return f"{cls.__name__}".lower() @abstractmethod - def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, Dict]: + def _create_augmented_graph(self, mol: Chem.Mol) -> tuple[torch.Tensor, dict]: """ Augments a molecule represented by a SMILES string. @@ -61,94 +70,7 @@ def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, Dict]: Returns: Tuple[torch.Tensor, Dict]: Graph edge index and augmented molecule information """ - - @abstractmethod - def _read_data(self, raw_data: str) -> GeomData: - """ - Reads raw data and returns a list of processed data. - - Args: - raw_data (str): Raw data input. - - Returns: - GeomData: `torch_geometric.data.Data` object. - """ - - def _smiles_to_mol(self, smiles: str) -> Chem.Mol: - """ - Converts a SMILES string to an RDKit molecule object. Sanitizes the molecule. - - Args: - smiles (str): SMILES string representing the molecule. - - Returns: - Chem.Mol: RDKit molecule object. - """ - mol = Chem.MolFromSmiles(smiles) - if mol is None: - print(f"RDKit failed to parse {smiles} (returned None)") - self.f_cnt_for_smiles += 1 - else: - try: - Chem.SanitizeMol(mol) - except Exception as e: - print(f"RDKit failed at sanitizing {smiles}, Error {e}") - self.f_cnt_for_smiles += 1 - return mol - - def on_finish(self) -> None: - """ - Finalizes the reading process and logs the number of failed SMILES and failed augmentation. - """ - print(f"Failed to read {self.f_cnt_for_smiles} SMILES in total") - print( - f"Failed to construct augmented graph for {self.f_cnt_for_aug_graph} number of SMILES" - ) - self.mol_object_buffer = {} - - def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: - """ - Reads a specific property from a molecule represented by a SMILES string. - - Args: - smiles (str): SMILES string representing the molecule. - property (MolecularProperty): Molecular property object for which the value needs to be extracted. 
- - Returns: - Optional[List]: Property values if molecule parsing is successful, else None. - """ - if smiles in self.mol_object_buffer: - return property.get_property_value(self.mol_object_buffer[smiles]) - - mol = self._smiles_to_mol(smiles) - if mol is None: - return None - - returned_result = self._create_augmented_graph(mol) - if returned_result is None: - return None - - _, augmented_mol = returned_result - return property.get_property_value(augmented_mol) - - -class GraphFGAugmentorReader(_AugmentorReader): - """ - A reader class that augments molecules with artificial functional group (FG) nodes and a graph-level node - to support graph-based molecular learning tasks. - - The FG nodes to connected to its related atoms and graph node is connected to all FG nodes. - """ - - @classmethod - def name(cls) -> str: - """ - Returns the name identifier of the augmentor. - - Returns: - str: Name identifier. - """ - return "graph_fg_augmentor" + pass def _read_data(self, smiles: str) -> GeomData | None: """ @@ -205,18 +127,37 @@ def _read_data(self, smiles: str) -> GeomData | None: is_atom_mask = torch.zeros(NUM_NODES, dtype=torch.bool) NUM_ATOM_NODES = augmented_molecule["nodes"]["atom_nodes"].GetNumAtoms() is_atom_mask[:NUM_ATOM_NODES] = True - is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool) - is_graph_node[-1] = True return GeomData( x=x, edge_index=edge_index, edge_attr=edge_attr, is_atom_node=is_atom_mask, - is_graph_node=is_graph_node, ) - def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: + def _smiles_to_mol(self, smiles: str) -> Chem.Mol: + """ + Converts a SMILES string to an RDKit molecule object. Sanitizes the molecule. + + Args: + smiles (str): SMILES string representing the molecule. + + Returns: + Chem.Mol: RDKit molecule object. 
+ """ + mol = Chem.MolFromSmiles(smiles) + if mol is None: + print(f"RDKit failed to parse {smiles} (returned None)") + self.f_cnt_for_smiles += 1 + else: + try: + Chem.SanitizeMol(mol) + except Exception as e: + print(f"RDKit failed at sanitizing {smiles}, Error {e}") + self.f_cnt_for_smiles += 1 + return mol + + def _create_augmented_graph(self, mol: Chem.Mol) -> tuple[torch.Tensor, dict]: """ Generates an augmented graph from a SMILES string. @@ -228,14 +169,88 @@ def _create_augmented_graph(self, mol: Chem.Mol) -> Tuple[torch.Tensor, dict]: - Augmented graph edge index, - Augmented graph (nodes and edges). """ - edge_index, node_info, edge_info = self._augment_graph_structure(mol) - augmented_molecule = {"nodes": node_info, "edges": edge_info} + augmented_mol = self._augment_graph_structure(mol) + + directed_edge_index = augmented_mol["directed_edge_index"] + if directed_edge_index is None or directed_edge_index.shape[0] != 2: + raise ValueError( + f"Expected directed_edge_index to have shape [2, num_edges], but got shape {directed_edge_index.shape}" + ) + + # First all directed edges from source to target are placed, then all directed edges from target to source + # are placed --- this is needed as it is easier to align the property values in same way + undirected_edge_index = torch.cat( + [ + directed_edge_index, + directed_edge_index[[1, 0], :], + ], + dim=1, + ) + + augmented_mol["edge_info"][k.NUM_EDGES] *= 2 # Undirected edges + augmented_molecule = { + "nodes": augmented_mol["node_info"], + "edges": augmented_mol["edge_info"], + } + + return undirected_edge_index, augmented_molecule + + @abstractmethod + def _augment_graph_structure(self, mol: Chem.Mol) -> dict: + """ + Constructs the full augmented graph structure from a molecule. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + dict: A dictionary containing: + - Augmented graph edge index, + - Augmented graph node attributes + - Augmented graph edge attributes. 
+ """ + pass + + def on_finish(self) -> None: + """ + Finalizes the reading process and logs the number of failed SMILES and failed augmentation. + """ + print(f"Failed to read {self.f_cnt_for_smiles} SMILES in total") + print( + f"Failed to construct augmented graph for {self.f_cnt_for_aug_graph} number of SMILES" + ) + self.mol_object_buffer = {} + + def read_property(self, smiles: str, property: MolecularProperty) -> Optional[list]: + """ + Reads a specific property from a molecule represented by a SMILES string. + + Args: + smiles (str): SMILES string representing the molecule. + property (MolecularProperty): Molecular property object for which the value needs to be extracted. + + Returns: + Optional[List]: Property values if molecule parsing is successful, else None. + """ + if smiles in self.mol_object_buffer: + return property.get_property_value(self.mol_object_buffer[smiles]) + + mol = self._smiles_to_mol(smiles) + if mol is None: + return None + + returned_result = self._create_augmented_graph(mol) + if returned_result is None: + return None + + _, augmented_mol = returned_result + return property.get_property_value(augmented_mol) - return edge_index, augmented_molecule +class AtomsFGReader(_AugmentorReader): def _augment_graph_structure( self, mol: Chem.Mol - ) -> Tuple[torch.Tensor, dict, dict]: + ) -> tuple[torch.Tensor, dict, dict]: """ Constructs the full augmented graph structure from a molecule. @@ -248,79 +263,55 @@ def _augment_graph_structure( - Augmented graph node attributes - Augmented graph edge attributes. 
""" - self._num_of_nodes = mol.GetNumAtoms() - self._num_of_edges = mol.GetNumBonds() + self._idx_of_node = mol.GetNumAtoms() + self._idx_of_edge = mol.GetNumBonds() self._annotate_atoms_and_bonds(mol) atom_edge_index = self._generate_atom_level_edge_index(mol) # Create FG-level structure and edges - fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds = ( + fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, fg_bonds = ( self._construct_fg_to_atom_structure(mol) ) - fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure( - fg_to_atoms_map, bonds - ) - - fg_graph_edge_index, graph_node, fg_to_graph_edges = ( - self._construct_fg_to_graph_node_structure(fg_to_atoms_map) - ) - # Merge all edge types directed_edge_index = torch.cat( [ atom_edge_index, torch.tensor(fg_atom_edge_index, dtype=torch.long), - torch.tensor(fg_internal_edge_index, dtype=torch.long), - torch.tensor(fg_graph_edge_index, dtype=torch.long), ], dim=1, ) - # First all directed edges from source to target are placed, then all directed edges from target to source - # are placed --- this is needed as it is easier to align the property values in same way - undirected_edge_index = torch.cat( - [directed_edge_index, directed_edge_index[[1, 0], :]], dim=1 - ) - total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes), 1]) + total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes)]) assert ( - self._num_of_nodes == total_atoms - ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._num_of_nodes}" - assert sys.version_info >= ( - 3, - 7, - ), "This code requires Python 3.7 or higher." 
- # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order - # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights - # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - # Order preservation is necessary to to create `is_atom_node` mask + self._idx_of_node == total_atoms + ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + node_info = { "atom_nodes": mol, "fg_nodes": fg_nodes, - "graph_node": graph_node, - "num_nodes": self._num_of_nodes, + "num_nodes": self._idx_of_node, } - total_edges = sum( - [ - mol.GetNumBonds(), - len(atom_fg_edges), - len(internal_fg_edges), - len(fg_to_graph_edges), - ] - ) + total_edges = sum([mol.GetNumBonds(), len(atom_fg_edges)]) assert ( - self._num_of_edges == total_edges - ), f"Mismatch in number of edges: expected {total_edges}, got {self._num_of_edges}" + self._idx_of_edge == total_edges + ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" edge_info = { k.WITHIN_ATOMS_EDGE: mol, k.ATOM_FG_EDGE: atom_fg_edges, - k.WITHIN_FG_EDGE: internal_fg_edges, - k.FG_GRAPHNODE_EDGE: fg_to_graph_edges, - k.NUM_EDGES: self._num_of_edges * 2, # Undirected edges + k.NUM_EDGES: self._idx_of_edge, + } + return { + "directed_edge_index": directed_edge_index, + "node_info": node_info, + "edge_info": edge_info, + "graph_meta_info": { + "fg_to_atoms_map": fg_to_atoms_map, + "fg_bonds": fg_bonds, + }, } - return undirected_edge_index, node_info, edge_info @staticmethod def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: @@ -387,7 +378,7 @@ def _construct_fg_to_atom_structure( molecule_atoms_set = set() for fg_smiles, fg_group in structure.items(): - fg_to_atoms_map[self._num_of_nodes] = fg_group + fg_to_atoms_map[self._idx_of_node] = fg_group is_ring_fg = fg_group["is_ring_fg"] connected_atoms = [] @@ -400,12 +391,12 @@ def _construct_fg_to_atom_structure( ) molecule_atoms_set.add(atom_idx) - 
fg_atom_edge_index[0].append(self._num_of_nodes) + fg_atom_edge_index[0].append(self._idx_of_node) fg_atom_edge_index[1].append(atom_idx) - atom_fg_edges[f"{self._num_of_nodes}_{atom_idx}"] = { + atom_fg_edges[f"{self._idx_of_node}_{atom_idx}"] = { k.EDGE_LEVEL: k.ATOM_FG_EDGE } - self._num_of_edges += 1 + self._idx_of_edge += 1 atom = mol.GetAtomWithIdx(atom_idx) connected_atoms.append(atom) @@ -415,14 +406,14 @@ def _construct_fg_to_atom_structure( else: self._set_fg_prop(connected_atoms, fg_nodes, fg_smiles) - self._num_of_nodes += 1 + self._idx_of_node += 1 return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds def _set_ring_fg_prop(self, connected_atoms, fg_nodes): # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings ring_size = len(connected_atoms) - fg_nodes[self._num_of_nodes] = { + fg_nodes[self._idx_of_node] = { k.NODE_LEVEL: k.FG_NODE_LEVEL, "FG": f"RING_{ring_size}", "RING": ring_size, @@ -479,16 +470,50 @@ def _set_fg_prop(self, connected_atoms, fg_nodes, fg_smiles): if representative_atom is None: raise AssertionError("Expected at least one atom with a functional group.") - fg_nodes[self._num_of_nodes] = { + fg_nodes[self._idx_of_node] = { k.NODE_LEVEL: k.FG_NODE_LEVEL, "FG": representative_atom.GetProp("FG"), "RING": 0, "is_alkyl": is_alkyl, } + +class AtomFGWithFGEdgesReader(AtomsFGReader): + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + augmented_struct = super()._augment_graph_structure(mol) + graph_meta_info = augmented_struct["graph_meta_info"] + + fg_to_atoms_map = graph_meta_info["fg_to_atoms_map"] + fg_bonds = graph_meta_info["fg_bonds"] + + fg_internal_edge_index, internal_fg_edges = self._construct_fg_level_structure( + fg_to_atoms_map, fg_bonds + ) + + augmented_struct["edge_info"][k.WITHIN_FG_EDGE] = internal_fg_edges + augmented_struct["edge_info"][k.NUM_EDGES] += len(internal_fg_edges) + + assert ( + self._idx_of_edge == 
augmented_struct["edge_info"][k.NUM_EDGES] + ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + assert ( + self._idx_of_node == augmented_struct["node_info"]["num_nodes"] + ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + + augmented_struct["directed_edge_index"] = torch.cat( + [ + augmented_struct["directed_edge_index"], + torch.tensor(fg_internal_edge_index, dtype=torch.long), + ], + dim=1, + ) + return augmented_struct + def _construct_fg_level_structure( self, fg_to_atoms_map: dict, bonds: list - ) -> Tuple[List[List[int]], dict]: + ) -> tuple[list[list[int]], dict]: """ Constructs internal edges between functional group nodes based on bond connections. @@ -518,7 +543,7 @@ def add_fg_internal_edge(source_fg, target_fg): internal_edge_index[0].append(source_fg) internal_edge_index[1].append(target_fg) internal_fg_edges[edge_str] = {k.EDGE_LEVEL: k.WITHIN_FG_EDGE} - self._num_of_edges += 1 + self._idx_of_edge += 1 for bond in bonds: source_atom, target_atom = bond[:2] @@ -549,9 +574,53 @@ def add_fg_internal_edge(source_fg, target_fg): return internal_edge_index, internal_fg_edges + +class AtomFGWithFGEdgesAndGraphNodeReader(AtomFGWithFGEdgesReader): + def _read_data(self, smiles): + geom_data = super()._read_data(smiles) + if geom_data is None: + return None + NUM_NODES = geom_data.x.shape[0] + is_graph_node = torch.zeros(NUM_NODES, dtype=torch.bool) + is_graph_node[-1] = True + geom_data.is_graph_node = is_graph_node + return geom_data + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + augmented_struct = super()._augment_graph_structure(mol) + graph_meta_info = augmented_struct["graph_meta_info"] + fg_to_atoms_map = graph_meta_info["fg_to_atoms_map"] + + fg_graph_edge_index, graph_node, fg_to_graph_edges = ( + self._construct_fg_to_graph_node_structure(fg_to_atoms_map) + ) + + 
augmented_struct["edge_info"][k.FG_GRAPHNODE_EDGE] = fg_to_graph_edges + augmented_struct["edge_info"][k.NUM_EDGES] += len(fg_to_graph_edges) + assert ( + self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] + ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" + + augmented_struct["node_info"]["graph_node"] = graph_node + augmented_struct["node_info"]["num_nodes"] += 1 + assert ( + self._idx_of_node == augmented_struct["node_info"]["num_nodes"] + ), f"Mismatch in number of nodes: expected {self._idx_of_node}, got {augmented_struct['node_info']['num_nodes']}" + + augmented_struct["directed_edge_index"] = torch.cat( + [ + augmented_struct["directed_edge_index"], + torch.tensor(fg_graph_edge_index, dtype=torch.long), + ], + dim=1, + ) + return augmented_struct + def _construct_fg_to_graph_node_structure( self, fg_to_atoms_map: dict - ) -> Tuple[List[List[int]], dict, dict]: + ) -> tuple[list[list[int]], dict, dict]: """ Constructs edges between functional group nodes and a global graph-level node. 
@@ -570,12 +639,20 @@ def _construct_fg_to_graph_node_structure( graph_edge_index = [[], []] for fg_id in fg_to_atoms_map: - graph_edge_index[0].append(self._num_of_nodes) + graph_edge_index[0].append(self._idx_of_node) graph_edge_index[1].append(fg_id) - fg_graph_edges[f"{self._num_of_nodes}_{fg_id}"] = { + fg_graph_edges[f"{self._idx_of_node}_{fg_id}"] = { k.EDGE_LEVEL: k.FG_GRAPHNODE_EDGE } - self._num_of_edges += 1 - self._num_of_nodes += 1 + self._idx_of_edge += 1 + self._idx_of_node += 1 return graph_edge_index, graph_node, fg_graph_edges + + +class AtomFGWithNoFGEdgesWithGraphNodeReader(_AugmentorReader): + pass + + +class AtomWithGraphNodeOnlyReader(_AugmentorReader): + pass diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index d16f2dd..808db89 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -11,7 +11,7 @@ from torch import Tensor from chebai_graph.preprocessing.properties import constants as k -from chebai_graph.preprocessing.reader import GraphFGAugmentorReader +from chebai_graph.preprocessing.reader import AtomFGWithFGEdgesAndGraphNodeReader matplotlib.use("TkAgg") @@ -414,7 +414,7 @@ class Main: """ def __init__(self): - self._fg_reader = GraphFGAugmentorReader() + self._fg_reader = AtomFGWithFGEdgesAndGraphNodeReader() def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> None: """ From 7d7b4bc16134ad5c02db33db1ccb8048ec344eb5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 4 Jul 2025 20:41:43 +0200 Subject: [PATCH 151/224] augmented readers for ablation studies --- chebai_graph/preprocessing/reader/__init__.py | 14 +- .../preprocessing/reader/augmented_reader.py | 135 +++++++++++------- .../utils/visualize_augmented_molecule.py | 4 +- 3 files changed, 101 insertions(+), 52 deletions(-) diff --git 
a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py index 3569e01..12df70b 100644 --- a/chebai_graph/preprocessing/reader/__init__.py +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -1,8 +1,18 @@ -from .augmented_reader import AtomFGWithFGEdgesAndGraphNodeReader +from .augmented_reader import ( + AtomFGReader_NoFGEdges_WithGraphNode, + AtomFGReader_WithFGEdges_NoGraphNode, + AtomFGReader_WithFGEdges_WithGraphNode, + AtomReader_WithGraphNodeOnly, + AtomsFGReader_NoFGEdges_NoGraphNode, +) from .reader import GraphPropertyReader, GraphReader __all__ = [ "GraphReader", "GraphPropertyReader", - "AtomFGWithFGEdgesAndGraphNodeReader", + "AtomReader_WithGraphNodeOnly", + "AtomsFGReader_NoFGEdges_NoGraphNode", + "AtomFGReader_NoFGEdges_WithGraphNode", + "AtomFGReader_WithFGEdges_NoGraphNode", + "AtomFGReader_WithFGEdges_WithGraphNode", ] diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index b917220..91a6844 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,7 +1,7 @@ import re import sys from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Optional import torch from chebai.preprocessing.reader import DataReader @@ -195,7 +195,6 @@ def _create_augmented_graph(self, mol: Chem.Mol) -> tuple[torch.Tensor, dict]: return undirected_edge_index, augmented_molecule - @abstractmethod def _augment_graph_structure(self, mol: Chem.Mol) -> dict: """ Constructs the full augmented graph structure from a molecule. @@ -209,7 +208,36 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: - Augmented graph node attributes - Augmented graph edge attributes. 
""" - pass + self._idx_of_node = mol.GetNumAtoms() + self._idx_of_edge = mol.GetNumBonds() + + self._annotate_atoms_and_bonds(mol) + atom_edge_index = self._generate_atom_level_edge_index(mol) + + total_atoms = mol.GetNumAtoms() + assert ( + self._idx_of_node == total_atoms + ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" + + node_info = { + "atom_nodes": mol, + "num_nodes": self._idx_of_node, + } + + total_edges = mol.GetNumBonds() + assert ( + self._idx_of_edge == total_edges + ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" + edge_info = { + k.WITHIN_ATOMS_EDGE: mol, + } + + return { + "directed_edge_index": atom_edge_index, + "node_info": node_info, + "edge_info": edge_info, + "graph_meta_info": {}, + } def on_finish(self) -> None: """ @@ -247,7 +275,7 @@ def read_property(self, smiles: str, property: MolecularProperty) -> Optional[li return property.get_property_value(augmented_mol) -class AtomsFGReader(_AugmentorReader): +class AtomsFGReader_NoFGEdges_NoGraphNode(_AugmentorReader): def _augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: @@ -263,11 +291,8 @@ def _augment_graph_structure( - Augmented graph node attributes - Augmented graph edge attributes. 
""" - self._idx_of_node = mol.GetNumAtoms() - self._idx_of_edge = mol.GetNumBonds() - - self._annotate_atoms_and_bonds(mol) - atom_edge_index = self._generate_atom_level_edge_index(mol) + augmented_mol = super()._augment_graph_structure(mol) + atom_edge_index = augmented_mol["directed_edge_index"] # Create FG-level structure and edges fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, fg_bonds = ( @@ -282,36 +307,26 @@ def _augment_graph_structure( ], dim=1, ) + augmented_mol["directed_edge_index"] = directed_edge_index total_atoms = sum([mol.GetNumAtoms(), len(fg_nodes)]) assert ( self._idx_of_node == total_atoms ), f"Mismatch in number of nodes: expected {total_atoms}, got {self._idx_of_node}" - - node_info = { - "atom_nodes": mol, - "fg_nodes": fg_nodes, - "num_nodes": self._idx_of_node, - } + augmented_mol["node_info"]["fg_nodes"] = fg_nodes + augmented_mol["node_info"]["num_nodes"] = self._idx_of_node total_edges = sum([mol.GetNumBonds(), len(atom_fg_edges)]) assert ( self._idx_of_edge == total_edges ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" - edge_info = { - k.WITHIN_ATOMS_EDGE: mol, - k.ATOM_FG_EDGE: atom_fg_edges, - k.NUM_EDGES: self._idx_of_edge, - } - return { - "directed_edge_index": directed_edge_index, - "node_info": node_info, - "edge_info": edge_info, - "graph_meta_info": { - "fg_to_atoms_map": fg_to_atoms_map, - "fg_bonds": fg_bonds, - }, - } + augmented_mol["edge_info"][k.ATOM_FG_EDGE] = atom_fg_edges + augmented_mol["edge_info"][k.NUM_EDGES] = self._idx_of_edge + + augmented_mol["graph_meta_info"]["fg_to_atoms_map"] = fg_to_atoms_map + augmented_mol["graph_meta_info"]["fg_bonds"] = fg_bonds + + return augmented_mol @staticmethod def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: @@ -478,7 +493,7 @@ def _set_fg_prop(self, connected_atoms, fg_nodes, fg_smiles): } -class AtomFGWithFGEdgesReader(AtomsFGReader): +class AtomFGReader_WithFGEdges_NoGraphNode(AtomsFGReader_NoFGEdges_NoGraphNode): def 
_augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: @@ -575,7 +590,7 @@ def add_fg_internal_edge(source_fg, target_fg): return internal_edge_index, internal_fg_edges -class AtomFGWithFGEdgesAndGraphNodeReader(AtomFGWithFGEdgesReader): +class _AddGraphNode(_AugmentorReader): def _read_data(self, smiles): geom_data = super()._read_data(smiles) if geom_data is None: @@ -586,15 +601,13 @@ def _read_data(self, smiles): geom_data.is_graph_node = is_graph_node return geom_data - def _augment_graph_structure( - self, mol: Chem.Mol + def _add_graph_node_and_edges_to_nodes( + self, + augmented_struct: dict, + nodes_ids: dict[int, Any] | set[int], ) -> tuple[torch.Tensor, dict, dict]: - augmented_struct = super()._augment_graph_structure(mol) - graph_meta_info = augmented_struct["graph_meta_info"] - fg_to_atoms_map = graph_meta_info["fg_to_atoms_map"] - fg_graph_edge_index, graph_node, fg_to_graph_edges = ( - self._construct_fg_to_graph_node_structure(fg_to_atoms_map) + self._construct_nodes_to_graph_node_structure(nodes_ids) ) augmented_struct["edge_info"][k.FG_GRAPHNODE_EDGE] = fg_to_graph_edges @@ -618,8 +631,8 @@ def _augment_graph_structure( ) return augmented_struct - def _construct_fg_to_graph_node_structure( - self, fg_to_atoms_map: dict + def _construct_nodes_to_graph_node_structure( + self, nodes_ids: dict ) -> tuple[list[list[int]], dict, dict]: """ Constructs edges between functional group nodes and a global graph-level node. 
@@ -635,24 +648,50 @@ def _construct_fg_to_graph_node_structure( """ graph_node = {k.NODE_LEVEL: k.GRAPH_NODE_LEVEL, "FG": "graph_fg", "RING": "0"} - fg_graph_edges = {} + graph_to_nodes_edges = {} graph_edge_index = [[], []] - for fg_id in fg_to_atoms_map: + for fg_id in nodes_ids: graph_edge_index[0].append(self._idx_of_node) graph_edge_index[1].append(fg_id) - fg_graph_edges[f"{self._idx_of_node}_{fg_id}"] = { + graph_to_nodes_edges[f"{self._idx_of_node}_{fg_id}"] = { k.EDGE_LEVEL: k.FG_GRAPHNODE_EDGE } self._idx_of_edge += 1 self._idx_of_node += 1 - return graph_edge_index, graph_node, fg_graph_edges + return graph_edge_index, graph_node, graph_to_nodes_edges + + +class AtomFGReader_WithFGEdges_WithGraphNode( + AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode +): + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + augmented_struct = super()._augment_graph_structure(mol) + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, fg_to_atoms_map + ) -class AtomFGWithNoFGEdgesWithGraphNodeReader(_AugmentorReader): - pass +class AtomFGReader_NoFGEdges_WithGraphNode( + AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode +): + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + augmented_struct = super()._augment_graph_structure(mol) + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, fg_to_atoms_map + ) -class AtomWithGraphNodeOnlyReader(_AugmentorReader): - pass +class AtomReader_WithGraphNodeOnly(_AddGraphNode): + def _augment_graph_structure(self, mol): + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return 
self._add_graph_node_and_edges_to_nodes(augmented_struct, atom_ids) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 808db89..21d5a13 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -11,7 +11,7 @@ from torch import Tensor from chebai_graph.preprocessing.properties import constants as k -from chebai_graph.preprocessing.reader import AtomFGWithFGEdgesAndGraphNodeReader +from chebai_graph.preprocessing.reader import AtomFGReader_WithFGEdges_WithGraphNode matplotlib.use("TkAgg") @@ -414,7 +414,7 @@ class Main: """ def __init__(self): - self._fg_reader = AtomFGWithFGEdgesAndGraphNodeReader() + self._fg_reader = AtomFGReader_WithFGEdges_WithGraphNode() def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> None: """ From bc605a821b66f962336b491906798ca311cc549a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 00:02:19 +0200 Subject: [PATCH 152/224] adapt visualization code for other aug readers --- .../preprocessing/reader/augmented_reader.py | 86 ++++------ .../utils/visualize_augmented_molecule.py | 158 ++++++++++++------ 2 files changed, 141 insertions(+), 103 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 91a6844..9817652 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,6 +1,6 @@ import re import sys -from abc import ABC, abstractmethod +from abc import ABC from typing import Any, Optional import torch @@ -59,19 +59,6 @@ def name(cls) -> str: """ return f"{cls.__name__}".lower() - @abstractmethod - def _create_augmented_graph(self, mol: Chem.Mol) -> tuple[torch.Tensor, dict]: - """ - Augments a molecule represented by a SMILES string. 
- - Args: - mol (Chem.Mol): RDKIT molecule. - - Returns: - Tuple[torch.Tensor, Dict]: Graph edge index and augmented molecule information - """ - pass - def _read_data(self, smiles: str) -> GeomData | None: """ Reads and augments molecular data from a SMILES string. @@ -230,6 +217,7 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: ), f"Mismatch in number of edges: expected {total_edges}, got {self._idx_of_edge}" edge_info = { k.WITHIN_ATOMS_EDGE: mol, + k.NUM_EDGES: self._idx_of_edge, } return { @@ -239,6 +227,37 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: "graph_meta_info": {}, } + @staticmethod + def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: + """ + Annotates each atom and bond with node and edge with certain properties. + + Args: + mol (Chem.Mol): RDKit molecule. + """ + for atom in mol.GetAtoms(): + atom.SetProp(k.NODE_LEVEL, k.ATOM_NODE_LEVEL) + for bond in mol.GetBonds(): + bond.SetProp(k.EDGE_LEVEL, k.WITHIN_ATOMS_EDGE) + + @staticmethod + def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: + """ + Generates bidirectional atom-level edge index tensor. + + Args: + mol (Chem.Mol): RDKit molecule. + + Returns: + torch.Tensor: Directed edge index tensor. + """ + # We need to ensure that directed edges which form a undirected edge are adjacent to each other + edge_index_list = [[], []] + for bond in mol.GetBonds(): + edge_index_list[0].append(bond.GetBeginAtomIdx()) + edge_index_list[1].append(bond.GetEndAtomIdx()) + return torch.tensor(edge_index_list, dtype=torch.long) + def on_finish(self) -> None: """ Finalizes the reading process and logs the number of failed SMILES and failed augmentation. @@ -328,37 +347,6 @@ def _augment_graph_structure( return augmented_mol - @staticmethod - def _annotate_atoms_and_bonds(mol: Chem.Mol) -> None: - """ - Annotates each atom and bond with node and edge with certain properties. - - Args: - mol (Chem.Mol): RDKit molecule. 
- """ - for atom in mol.GetAtoms(): - atom.SetProp(k.NODE_LEVEL, k.ATOM_NODE_LEVEL) - for bond in mol.GetBonds(): - bond.SetProp(k.EDGE_LEVEL, k.WITHIN_ATOMS_EDGE) - - @staticmethod - def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: - """ - Generates bidirectional atom-level edge index tensor. - - Args: - mol (Chem.Mol): RDKit molecule. - - Returns: - torch.Tensor: Directed edge index tensor. - """ - # We need to ensure that directed edges which form a undirected edge are adjacent to each other - edge_index_list = [[], []] - for bond in mol.GetBonds(): - edge_index_list[0].append(bond.GetBeginAtomIdx()) - edge_index_list[1].append(bond.GetEndAtomIdx()) - return torch.tensor(edge_index_list, dtype=torch.long) - def _construct_fg_to_atom_structure( self, mol: Chem.Mol ) -> tuple[list[list[int]], dict, dict, dict, list]: @@ -606,12 +594,12 @@ def _add_graph_node_and_edges_to_nodes( augmented_struct: dict, nodes_ids: dict[int, Any] | set[int], ) -> tuple[torch.Tensor, dict, dict]: - fg_graph_edge_index, graph_node, fg_to_graph_edges = ( + nodes_graph_edge_index, graph_node, nodes_to_graph_edges = ( self._construct_nodes_to_graph_node_structure(nodes_ids) ) - augmented_struct["edge_info"][k.FG_GRAPHNODE_EDGE] = fg_to_graph_edges - augmented_struct["edge_info"][k.NUM_EDGES] += len(fg_to_graph_edges) + augmented_struct["edge_info"][k.FG_GRAPHNODE_EDGE] = nodes_to_graph_edges + augmented_struct["edge_info"][k.NUM_EDGES] += len(nodes_to_graph_edges) assert ( self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] ), f"Mismatch in number of edges: expected {self._idx_of_edge}, got {augmented_struct['edge_info'][k.NUM_EDGES]}" @@ -625,7 +613,7 @@ def _add_graph_node_and_edges_to_nodes( augmented_struct["directed_edge_index"] = torch.cat( [ augmented_struct["directed_edge_index"], - torch.tensor(fg_graph_edge_index, dtype=torch.long), + torch.tensor(nodes_graph_edge_index, dtype=torch.long), ], dim=1, ) diff --git 
a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 21d5a13..06f5ba2 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -11,7 +11,13 @@ from torch import Tensor from chebai_graph.preprocessing.properties import constants as k -from chebai_graph.preprocessing.reader import AtomFGReader_WithFGEdges_WithGraphNode +from chebai_graph.preprocessing.reader import ( + AtomFGReader_NoFGEdges_WithGraphNode, + AtomFGReader_WithFGEdges_NoGraphNode, + AtomFGReader_WithFGEdges_WithGraphNode, + AtomReader_WithGraphNodeOnly, + AtomsFGReader_NoFGEdges_NoGraphNode, +) matplotlib.use("TkAgg") @@ -39,6 +45,14 @@ BondType.DATIVEONE: "red", } +READER = { + "n_fge_w_gn": AtomFGReader_NoFGEdges_WithGraphNode, + "w_fge_w_gn": AtomFGReader_WithFGEdges_WithGraphNode, + "w_fge_n_gn": AtomFGReader_WithFGEdges_NoGraphNode, + "n_fge_n_gn": AtomsFGReader_NoFGEdges_NoGraphNode, + "atom_w_gn": AtomReader_WithGraphNodeOnly, +} + def _create_graph( edge_index: Tensor, augmented_graph_nodes: dict, augmented_graph_edges: dict @@ -68,33 +82,47 @@ def _create_graph( ) # Add functional group (FG) nodes - fg_nodes = augmented_graph_nodes["fg_nodes"] - for fg_idx, fg_props in fg_nodes.items(): + if "fg_nodes" in augmented_graph_nodes: + fg_nodes = augmented_graph_nodes["fg_nodes"] + for fg_idx, fg_props in fg_nodes.items(): + G.add_node( + fg_idx, + node_name=f"FG:{fg_props['FG']}", + node_type="fg", + node_color=NODE_COLOR_MAP["fg"], + ) + + if "graph_node" in augmented_graph_nodes: + graph_node_idx = augmented_graph_nodes["num_nodes"] - 1 G.add_node( - fg_idx, - node_name=f"FG:{fg_props['FG']}", - node_type="fg", - node_color=NODE_COLOR_MAP["fg"], + graph_node_idx, + node_name="Graph Node", + node_type="graph", + node_color=NODE_COLOR_MAP["graph"], ) - # Add special graph node - graph_node_idx = 
augmented_graph_nodes["num_nodes"] - 1 - G.add_node( - graph_node_idx, - node_name="Graph Node", - node_type="graph", - node_color=NODE_COLOR_MAP["graph"], - ) - # Decode edge types and add edges with proper color and type src_nodes, tgt_nodes = edge_index.tolist() with_atom_edges = { f"{bond.GetBeginAtomIdx()}_{bond.GetEndAtomIdx()}" for bond in augmented_graph_edges[k.WITHIN_ATOMS_EDGE].GetBonds() } - atom_fg_edges = set(augmented_graph_edges[k.ATOM_FG_EDGE]) - within_fg_edges = set(augmented_graph_edges[k.WITHIN_FG_EDGE]) - fg_graph_edges = set(augmented_graph_edges[k.FG_GRAPHNODE_EDGE]) + + atom_fg_edges = ( + set(augmented_graph_edges[k.ATOM_FG_EDGE]) + if k.ATOM_FG_EDGE in augmented_graph_edges + else set() + ) + within_fg_edges = ( + set(augmented_graph_edges[k.WITHIN_FG_EDGE]) + if k.WITHIN_FG_EDGE in augmented_graph_edges + else set() + ) + fg_graph_edges = ( + set(augmented_graph_edges[k.FG_GRAPHNODE_EDGE]) + if k.FG_GRAPHNODE_EDGE in augmented_graph_edges + else set() + ) for src, tgt in zip(src_nodes, tgt_nodes): undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} @@ -127,7 +155,7 @@ def _get_subgraph_by_node_type(G: nx.Graph, node_type: str) -> nx.Graph: selected_nodes = [ n for n, attr in G.nodes(data=True) if attr.get("node_type") == node_type ] - return G.subgraph(selected_nodes).copy() + return G.subgraph(selected_nodes).copy() if selected_nodes else None def _draw_hierarchy(G: nx.Graph, mol: Mol) -> None: @@ -235,32 +263,53 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: # Dictionary to store functional group node positions fg_pos = {} - - # Loop through each functional group node in the graph - for fg_node in _get_subgraph_by_node_type(G, "fg").nodes(): - # Get connected atom nodes (assuming edges are between fg and atom nodes) - connected_atoms = [ - nbr - for nbr in G.neighbors(fg_node) - if G.nodes[nbr].get("node_type") == "atom" - ] - - # Get the 2D positions of the connected atoms - positions = np.array([atom_pos[atom] for atom in 
connected_atoms]) - x_mean, y_mean = positions[:, 0].mean(), positions[:, 1].mean() - fg_pos[fg_node] = (x_mean, y_mean, 2) # z = 2 for elevation - - graph_node = next(iter(_get_subgraph_by_node_type(G, "graph").nodes())) - graph_pos_arr = np.array( - [ - fg_pos[nbr] - for nbr in G.neighbors(graph_node) - if G.nodes[nbr].get("node_type") == "fg" - ] - ) - graph_pos = { - graph_node: (graph_pos_arr[:, 0].mean(), graph_pos_arr[:, 1].mean(), 4) - } + fg_subgraph = _get_subgraph_by_node_type(G, "fg") + if fg_subgraph: + # Loop through each functional group node in the graph + for fg_node in fg_subgraph.nodes(): + # Get connected atom nodes (assuming edges are between fg and atom nodes) + connected_atoms = [ + nbr + for nbr in G.neighbors(fg_node) + if G.nodes[nbr].get("node_type") == "atom" + ] + + # Get the 2D positions of the connected atoms + positions = np.array([atom_pos[atom] for atom in connected_atoms]) + x_mean, y_mean = positions[:, 0].mean(), positions[:, 1].mean() + fg_pos[fg_node] = (x_mean, y_mean, 2) # z = 2 for elevation + + graph_pos = {} + graph_subgraph = _get_subgraph_by_node_type(G, "graph") + if graph_subgraph: + graph_node = next(iter(graph_subgraph.nodes())) + neighbor_type = { + G.nodes[nbr].get("node_type") for nbr in G.neighbors(graph_node) + } + assert neighbor_type < { + "fg", + "atom", + }, f"Graph node {graph_node} must only connect to one type of node: {neighbor_type}" + + if "fg" in neighbor_type: + graph_pos_arr = np.array( + [ + fg_pos[nbr] + for nbr in G.neighbors(graph_node) + if G.nodes[nbr].get("node_type") == "fg" + ] + ) + else: + graph_pos_arr = np.array( + [ + atom_pos[nbr] + for nbr in G.neighbors(graph_node) + if G.nodes[nbr].get("node_type") == "atom" + ] + ) + graph_pos = { + graph_node: (graph_pos_arr[:, 0].mean(), graph_pos_arr[:, 1].mean(), 4) + } pos = {**atom_pos, **fg_pos, **graph_pos} @@ -413,10 +462,12 @@ class Main: Command-line wrapper class for plotting augmented molecular graphs. 
""" - def __init__(self): - self._fg_reader = AtomFGReader_WithFGEdges_WithGraphNode() - - def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> None: + @staticmethod + def plot( + smiles: str = "OC(=O)c1ccccc1O", + plot_type: str = "simple", + reader: str = "w_fge_w_gn", + ) -> None: """ Plot an augmented molecular graph from SMILES. @@ -427,7 +478,8 @@ def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> No - h: Hierarchical 2D-graph with separate plane for each node type - 3d: Hierarchical 3D-graph """ - mol = self._fg_reader._smiles_to_mol(smiles) # noqa + fg_reader = READER[reader]() + mol = fg_reader._smiles_to_mol(smiles) if mol is None: raise ValueError(f"Invalid SMILES: {smiles}") @@ -435,9 +487,7 @@ def plot(self, smiles: str = "OC(=O)c1ccccc1O", plot_type: str = "simple") -> No plot_nonaugment_molecule_graph(mol) return - edge_index, augmented_molecule = self._fg_reader._create_augmented_graph( - mol - ) # noqa + edge_index, augmented_molecule = fg_reader._create_augmented_graph(mol) plot_augmented_graph(edge_index, augmented_molecule, mol, plot_type) From 7560fec5240e71114b256a5faa3ef85e45658551 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 12:30:02 +0200 Subject: [PATCH 153/224] add data classes for ablation readers --- .../preprocessing/datasets/__init__.py | 11 ++- chebai_graph/preprocessing/datasets/chebi.py | 69 +++++++++++++++---- configs/data/chebi50_augmented_baseline.yml | 2 +- 3 files changed, 68 insertions(+), 14 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py index 43aed70..5f8ae1e 100644 --- a/chebai_graph/preprocessing/datasets/__init__.py +++ b/chebai_graph/preprocessing/datasets/__init__.py @@ -1,6 +1,10 @@ from .chebi import ( + ChEBI50_Atom_WGNOnly_GraphProp, + ChEBI50_NFGE_NGN_GraphProp, + ChEBI50_NFGE_WGN_GraphProp, + ChEBI50_WFGE_NGN_GraphProp, + ChEBI50_WFGE_WGN_GraphProp, ChEBI50GraphData, 
- ChEBI50GraphFGAugmentorReader, ChEBI50GraphProperties, ) from .pubchem import PubChemGraphProperties @@ -10,4 +14,9 @@ "ChEBI50GraphProperties", "ChEBI50GraphData", "PubChemGraphProperties", + "ChEBI50_Atom_WGNOnly_GraphProp", + "ChEBI50_NFGE_NGN_GraphProp", + "ChEBI50_NFGE_WGN_GraphProp", + "ChEBI50_WFGE_NGN_GraphProp", + "ChEBI50_WFGE_WGN_GraphProp", ] diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index ca07435..be513de 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -22,7 +22,11 @@ MolecularProperty, ) from chebai_graph.preprocessing.reader import ( - GraphFGAugmentorReader, + AtomFGReader_NoFGEdges_WithGraphNode, + AtomFGReader_WithFGEdges_NoGraphNode, + AtomFGReader_WithFGEdges_WithGraphNode, + AtomReader_WithGraphNodeOnly, + AtomsFGReader_NoFGEdges_NoGraphNode, GraphPropertyReader, GraphReader, ) @@ -178,20 +182,11 @@ def _merge_props_into_base(self, row): ) else: molecule_attr = torch.cat([molecule_attr, property_values], dim=1) - - is_atom_node = ( - geom_data.is_atom_node if hasattr(geom_data, "is_atom_node") else None - ) - is_graph_node = ( - geom_data.is_graph_node if hasattr(geom_data, "is_graph_node") else None - ) return GeomData( x=x, edge_index=geom_data.edge_index, edge_attr=edge_attr, molecule_attr=molecule_attr, - is_atom_node=is_atom_node, - is_graph_node=is_graph_node, ) def load_processed_data_from_file(self, filename): @@ -249,5 +244,55 @@ class ChEBI50GraphPropertiesPartial(ChEBI50GraphProperties, ChEBIOverXPartial): pass -class ChEBI50GraphFGAugmentorReader(GraphPropertiesMixIn, ChEBIOver50): - READER = GraphFGAugmentorReader +class AugGraphPropMixIn_NoGraphNode(GraphPropertiesMixIn, ABC): + READER = None + + def _merge_props_into_base(self, row): + data = super()._merge_props_into_base(row) + geom_data = row["features"] + assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) + + is_atom_node = 
geom_data.is_atom_node + assert is_atom_node is not None, "is_atom_node must be set in the geom_data" + data.is_atom_node = is_atom_node + return data + + +class AugGraphPropMixIn_WithGraphNode(AugGraphPropMixIn_NoGraphNode, ABC): + READER = None + + def _merge_props_into_base(self, row): + data = super()._merge_props_into_base(row) + return self._add_graph_node_mask(data, row) + + def _add_graph_node_mask(self, data: GeomData, row) -> GeomData: + """ + Add a mask for graph nodes to the data. + This is used to distinguish between atom nodes and graph nodes. + """ + geom_data = row["features"] + assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) + is_graph_node = geom_data.is_graph_node + assert is_graph_node is not None, "is_graph_node must be set in the geom_data" + data.is_graph_node = is_graph_node + return data + + +class ChEBI50_WFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + READER = AtomFGReader_WithFGEdges_WithGraphNode + + +class ChEBI50_NFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + READER = AtomFGReader_NoFGEdges_WithGraphNode + + +class ChEBI50_WFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): + READER = AtomFGReader_WithFGEdges_NoGraphNode + + +class ChEBI50_NFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): + READER = AtomsFGReader_NoFGEdges_NoGraphNode + + +class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + READER = AtomReader_WithGraphNodeOnly diff --git a/configs/data/chebi50_augmented_baseline.yml b/configs/data/chebi50_augmented_baseline.yml index 7f9985c..cab8b56 100644 --- a/configs/data/chebi50_augmented_baseline.yml +++ b/configs/data/chebi50_augmented_baseline.yml @@ -1,4 +1,4 @@ -class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphFGAugmentorReader +class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp init_args: properties: - chebai_graph.preprocessing.properties.AugAtomType From 
416a39553350445ae3fe27d8aec92ae1ee3e2523 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 13:30:45 +0200 Subject: [PATCH 154/224] adapt aug props for ablation data classes --- .../properties/augmented_properties.py | 96 +++++++++++-------- .../preprocessing/properties/constants.py | 4 +- .../preprocessing/reader/augmented_reader.py | 4 +- .../utils/visualize_augmented_molecule.py | 10 +- 4 files changed, 66 insertions(+), 48 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index c174340..8739c6f 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -1,3 +1,4 @@ +import sys from abc import ABC from typing import Dict, List, Optional @@ -13,6 +14,15 @@ from . import properties as pr from .base import AtomProperty, BondProperty, FrozenPropertyAlias +# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order +# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights +# https://mail.python.org/pipermail/python-dev/2017-December/151283.html +assert sys.version_info >= ( + 3, + 7, +), "This code requires Python 3.7 or higher." 
+# Order preservation is necessary to to create `prop_list` + # --------------------- Atom Properties ----------------------------- class AugmentedAtomProperty(AtomProperty, ABC): @@ -24,9 +34,7 @@ def get_property_value(self, augmented_mol: Dict): f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" ) - missing_keys = {"atom_nodes", "fg_nodes", "graph_node"} - augmented_mol[ - self.MAIN_KEY - ].keys() + missing_keys = {"atom_nodes"} - augmented_mol[self.MAIN_KEY].keys() if missing_keys: raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") @@ -35,26 +43,29 @@ def get_property_value(self, augmented_mol: Dict): raise TypeError( f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' ) - prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] - fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] - graph_node = augmented_mol[self.MAIN_KEY]["graph_node"] - if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict ' - f"containing its properties" - ) + if "fg_nodes" in augmented_mol[self.MAIN_KEY]: + fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] + if not isinstance(fg_nodes, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]) must be an instance of dict ' + f"containing its properties" + ) + prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()]) + + if "graph_node" in augmented_mol[self.MAIN_KEY]: + graph_node = augmented_mol[self.MAIN_KEY]["graph_node"] + if not isinstance(graph_node, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["graph_node"]) must be an instance of dict ' + f"containing its properties" + ) + prop_list.append(self.get_atom_value(graph_node)) - # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order - # 
https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights - # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()]) - prop_list.append(self.get_atom_value(graph_node)) assert ( len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"] ), "Number of property values should be equal to number of nodes" - return prop_list def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): @@ -228,7 +239,7 @@ def get_property_value(self, augmented_mol: Dict) -> List: f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" ) - missing_keys = k.EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys() + missing_keys = {k.WITHIN_ATOMS_EDGE} - augmented_mol[self.MAIN_KEY].keys() if missing_keys: raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") @@ -237,31 +248,38 @@ def get_property_value(self, augmented_mol: Dict) -> List: raise TypeError( f'augmented_mol["{self.MAIN_KEY}"]["{k.WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' ) - prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] - fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE] - fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE] - fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.FG_GRAPHNODE_EDGE] - - if ( - not isinstance(fg_atom_edges, dict) - or not isinstance(fg_edges, dict) - or not isinstance(fg_graph_node_edges, dict) - ): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["{k.ATOM_FG_EDGE}"]/["{k.WITHIN_FG_EDGE}"]/["{k.FG_GRAPHNODE_EDGE}"]) ' - f"must be an instance of dict containing its properties" + if k.ATOM_FG_EDGE in augmented_mol[self.MAIN_KEY]: + fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE] + if not isinstance(fg_atom_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.ATOM_FG_EDGE}'])" + f"must be an instance of dict containing its properties" + ) 
+ prop_list.extend( + [self.get_bond_value(bond) for bond in fg_atom_edges.values()] ) - # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order - # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights - # https://mail.python.org/pipermail/python-dev/2017-December/151283.html - prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges.values()]) - prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) - prop_list.extend( - [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] - ) + if k.WITHIN_FG_EDGE in augmented_mol[self.MAIN_KEY]: + fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE] + if not isinstance(fg_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.WITHIN_FG_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) + + if k.TO_GRAPHNODE_EDGE in augmented_mol[self.MAIN_KEY]: + fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.TO_GRAPHNODE_EDGE] + if not isinstance(fg_graph_node_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.TO_GRAPHNODE_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend( + [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] + ) num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2 assert ( diff --git a/chebai_graph/preprocessing/properties/constants.py b/chebai_graph/preprocessing/properties/constants.py index f64e5cb..e73da46 100644 --- a/chebai_graph/preprocessing/properties/constants.py +++ b/chebai_graph/preprocessing/properties/constants.py @@ -8,6 +8,6 @@ WITHIN_ATOMS_EDGE = "within_atoms_lvl" WITHIN_FG_EDGE = "within_fg_lvl" ATOM_FG_EDGE = "atom_fg_lvl" -FG_GRAPHNODE_EDGE = "fg_graphNode_lvl" -EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_EDGE} +TO_GRAPHNODE_EDGE = 
"fg_graphNode_lvl" +EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, TO_GRAPHNODE_EDGE} NUM_EDGES = "num_undirected_edges" diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 9817652..c11df8f 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -598,7 +598,7 @@ def _add_graph_node_and_edges_to_nodes( self._construct_nodes_to_graph_node_structure(nodes_ids) ) - augmented_struct["edge_info"][k.FG_GRAPHNODE_EDGE] = nodes_to_graph_edges + augmented_struct["edge_info"][k.TO_GRAPHNODE_EDGE] = nodes_to_graph_edges augmented_struct["edge_info"][k.NUM_EDGES] += len(nodes_to_graph_edges) assert ( self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES] @@ -643,7 +643,7 @@ def _construct_nodes_to_graph_node_structure( graph_edge_index[0].append(self._idx_of_node) graph_edge_index[1].append(fg_id) graph_to_nodes_edges[f"{self._idx_of_node}_{fg_id}"] = { - k.EDGE_LEVEL: k.FG_GRAPHNODE_EDGE + k.EDGE_LEVEL: k.TO_GRAPHNODE_EDGE } self._idx_of_edge += 1 self._idx_of_node += 1 diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 06f5ba2..a75440d 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -25,7 +25,7 @@ k.WITHIN_ATOMS_EDGE: "#1f77b4", k.ATOM_FG_EDGE: "#9467bd", k.WITHIN_FG_EDGE: "#ff7f0e", - k.FG_GRAPHNODE_EDGE: "#2ca02c", + k.TO_GRAPHNODE_EDGE: "#2ca02c", } NODE_COLOR_MAP = { @@ -119,8 +119,8 @@ def _create_graph( else set() ) fg_graph_edges = ( - set(augmented_graph_edges[k.FG_GRAPHNODE_EDGE]) - if k.FG_GRAPHNODE_EDGE in augmented_graph_edges + set(augmented_graph_edges[k.TO_GRAPHNODE_EDGE]) + if k.TO_GRAPHNODE_EDGE in augmented_graph_edges else set() ) @@ -133,7 +133,7 @@ def _create_graph( elif 
undirected_edge_set & within_fg_edges: edge_type = k.WITHIN_FG_EDGE elif undirected_edge_set & fg_graph_edges: - edge_type = k.FG_GRAPHNODE_EDGE + edge_type = k.TO_GRAPHNODE_EDGE else: raise ValueError("Unexpected edge type") G.add_edge(src, tgt, edge_type=edge_type, edge_color=EDGE_COLOR_MAP[edge_type]) @@ -318,7 +318,7 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: k.WITHIN_ATOMS_EDGE: [], k.ATOM_FG_EDGE: [], k.WITHIN_FG_EDGE: [], - k.FG_GRAPHNODE_EDGE: [], + k.TO_GRAPHNODE_EDGE: [], } for src, tgt, data in G.edges(data=True): edge_type_to_edges[data["edge_type"]].append((src, tgt)) From cf33cdebfd65534e2c86d1be5299d3e3e4e1bb9e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 15:54:47 +0200 Subject: [PATCH 155/224] doc for chebi dataset class --- chebai_graph/preprocessing/datasets/chebi.py | 155 +++++++++++++------ chebai_graph/preprocessing/datasets/utils.py | 43 +++++ 2 files changed, 151 insertions(+), 47 deletions(-) create mode 100644 chebai_graph/preprocessing/datasets/utils.py diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index be513de..bf148d2 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,7 +1,6 @@ -import importlib import os from abc import ABC -from typing import Callable, List, Optional +from collections.abc import Callable import pandas as pd import torch @@ -15,7 +14,6 @@ from lightning_utilities.core.rank_zero import rank_zero_info from torch_geometric.data.data import Data as GeomData -import chebai_graph.preprocessing.properties as graph_properties from chebai_graph.preprocessing.properties import ( AtomProperty, BondProperty, @@ -31,42 +29,40 @@ GraphReader, ) +from .utils import resolve_property + class ChEBI50GraphData(ChEBIOver50): + """ChEBI dataset with at least 50 samples per class, using GraphReader.""" + READER = GraphReader def __init__(self, **kwargs): super().__init__(**kwargs) -def 
_resolve_property( - property, #: str | properties.MolecularProperty -) -> MolecularProperty: - # if property is given as a string, try to resolve as a class path - if isinstance(property, MolecularProperty): - return property - try: - # split class_path into module-part and class name - last_dot = property.rindex(".") - module_name = property[:last_dot] - class_name = property[last_dot + 1 :] - module = importlib.import_module(module_name) - return getattr(module, class_name)() - except ValueError: - # if only a class name is given, assume the module is chebai_graph.processing.properties - return getattr(graph_properties, property)() - - class GraphPropertiesMixIn(ChEBIOverX, ABC): + """Mixin for adding molecular property encodings to graph-based ChEBI datasets.""" + READER = GraphPropertyReader def __init__( - self, properties: Optional[List], transform: Optional[Callable] = None, **kwargs + self, + properties: list | None = None, + transform: Callable | None = None, + **kwargs, ): + """ + Initialize GraphPropertiesMixIn. + + Args: + properties: Optional list of MolecularProperty class paths or instances. + transform: Optional transformation applied to each data sample. + """ super().__init__(**kwargs) # atom_properties and bond_properties are given as lists containing class_paths if properties is not None: - properties = [_resolve_property(prop) for prop in properties] + properties = [resolve_property(prop) for prop in properties] properties = sorted( properties, key=lambda prop: self.get_property_path(prop) ) @@ -81,7 +77,13 @@ def __init__( ) self.transform = transform - def _setup_properties(self): + def _setup_properties(self) -> None: + """ + Process and cache molecular properties to disk. 
+ + Returns: + None + """ raw_data = [] os.makedirs(self.processed_properties_dir, exist_ok=True) @@ -101,14 +103,16 @@ def _setup_properties(self): file, ) raw_data += list(self._load_dict(path)) + idents = [row["ident"] for row in raw_data] features = [row["features"] for row in raw_data] def enc_if_not_none(encode, value): - if value is not None and len(value) > 0: - return [encode(atom_v) for atom_v in value] - else: - return None + return ( + [encode(v) for v in value] + if value is not None and len(value) > 0 + else None + ) for property in self.properties: if not os.path.isfile(self.get_property_path(property)): @@ -137,30 +141,53 @@ def enc_if_not_none(encode, value): property.on_finish() @property - def processed_properties_dir(self): + def processed_properties_dir(self) -> str: return os.path.join(self.processed_dir, "properties") - def get_property_path(self, property: MolecularProperty): + def get_property_path(self, property: MolecularProperty) -> str: + """ + Construct the cache path for a given molecular property. + + Args: + property: Instance of a MolecularProperty. + + Returns: + Path to the cached property file. + """ return os.path.join( self.processed_properties_dir, f"{property.name}_{property.encoder.name}.pt", ) - def _after_setup(self, **kwargs): + def _after_setup(self, **kwargs) -> None: """ - Finalize the setup process after ensuring the processed data is available. + Finalize setup after ensuring properties are processed. + + Args: + **kwargs: Additional keyword arguments passed to superclass. - This method performs post-setup tasks like finalizing the reader and setting internal properties. + Returns: + None """ self._setup_properties() super()._after_setup(**kwargs) - def _merge_props_into_base(self, row): + def _merge_props_into_base(self, row: pd.Series) -> GeomData: + """ + Merge encoded molecular properties into the GeomData object. + + Args: + row: A dictionary containing 'features' and encoded properties. 
+ + Returns: + A GeomData object with merged features. + """ geom_data = row["features"] assert isinstance(geom_data, GeomData) edge_attr = geom_data.edge_attr x = geom_data.x molecule_attr = torch.empty((1, 0)) + for property in self.properties: property_values = row[f"{property.name}"] if isinstance(property_values, torch.Tensor): @@ -172,6 +199,7 @@ def _merge_props_into_base(self, row): property_values = torch.zeros( (0, property.encoder.get_encoding_length()) ) + if isinstance(property, AtomProperty): x = torch.cat([x, property_values], dim=1) elif isinstance(property, BondProperty): @@ -182,6 +210,7 @@ def _merge_props_into_base(self, row): ) else: molecule_attr = torch.cat([molecule_attr, property_values], dim=1) + return GeomData( x=x, edge_index=geom_data.edge_index, @@ -189,9 +218,19 @@ def _merge_props_into_base(self, row): molecule_attr=molecule_attr, ) - def load_processed_data_from_file(self, filename): + def load_processed_data_from_file(self, filename: str) -> list[dict]: + """ + Load dataset and merge cached properties into base features. + + Args: + filename: The path to the file to load. + + Returns: + List of data entries, each a dictionary. + """ base_data = super().load_processed_data_from_file(filename) base_df = pd.DataFrame(base_data) + for property in self.properties: property_data = torch.load( self.get_property_path(property), weights_only=False @@ -218,36 +257,40 @@ def load_processed_data_from_file(self, filename): prop_lengths = [ (prop.name, prop.encoder.get_encoding_length()) for prop in self.properties ] - rank_zero_info( - f"Finished loading dataset from properties." 
- f"\nEncoding lengths are: " - f"{prop_lengths}" - f"\nIf you train a model with these properties and encodings, " - f"use n_atom_properties: {sum([prop.encoder.get_encoding_length() for prop in self.properties if isinstance(prop, AtomProperty)])}, " - f"n_bond_properties: {sum([prop.encoder.get_encoding_length() for prop in self.properties if isinstance(prop, BondProperty)])} " - f"and n_molecule_properties: {sum([prop.encoder.get_encoding_length() for prop in self.properties if not (isinstance(prop, AtomProperty) or isinstance(prop, BondProperty))])}" + f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" + f"Use n_atom_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, " + f"n_bond_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, " + f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if not isinstance(p, (AtomProperty, BondProperty)))}" ) return base_df[base_data[0].keys()].to_dict("records") class ChEBI50GraphProperties(GraphPropertiesMixIn, ChEBIOver50): + """ChEBIOver50 dataset with molecular property encodings.""" + pass class ChEBI100GraphProperties(GraphPropertiesMixIn, ChEBIOver100): + """ChEBIOver100 dataset with molecular property encodings.""" + pass class ChEBI50GraphPropertiesPartial(ChEBI50GraphProperties, ChEBIOverXPartial): + """Partial version of ChEBIOver50 with molecular properties.""" + pass class AugGraphPropMixIn_NoGraphNode(GraphPropertiesMixIn, ABC): + """Mixin for augmented graph data without additional graph nodes.""" + READER = None - def _merge_props_into_base(self, row): + def _merge_props_into_base(self, row: pd.Series) -> GeomData: data = super()._merge_props_into_base(row) geom_data = row["features"] assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) @@ -259,16 +302,24 @@ def _merge_props_into_base(self, row): class 
AugGraphPropMixIn_WithGraphNode(AugGraphPropMixIn_NoGraphNode, ABC): + """Mixin for augmented graph data with graph-level nodes.""" + READER = None - def _merge_props_into_base(self, row): + def _merge_props_into_base(self, row: pd.Series) -> GeomData: data = super()._merge_props_into_base(row) return self._add_graph_node_mask(data, row) def _add_graph_node_mask(self, data: GeomData, row) -> GeomData: """ - Add a mask for graph nodes to the data. - This is used to distinguish between atom nodes and graph nodes. + Add a graph node mask to the GeomData object. + + Args: + data: A GeomData object with features. + row: A dictionary containing 'features' and other metadata. + + Returns: + Modified GeomData with graph node mask added. """ geom_data = row["features"] assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) @@ -279,20 +330,30 @@ def _add_graph_node_mask(self, data: GeomData, row) -> GeomData: class ChEBI50_WFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + """ChEBIOver50 with with FG nodes and FG edges and graph node.""" + READER = AtomFGReader_WithFGEdges_WithGraphNode class ChEBI50_NFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + """ChEBIOver50 with FG nodes but without FG edges, with graph node.""" + READER = AtomFGReader_NoFGEdges_WithGraphNode class ChEBI50_WFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): + """ChEBIOver50 with FG nodes and FG edges, no graph node.""" + READER = AtomFGReader_WithFGEdges_NoGraphNode class ChEBI50_NFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): + """ChEBIOver50 with FG nodes but without FG edges or graph node.""" + READER = AtomsFGReader_NoFGEdges_NoGraphNode class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): + """ChEBIOver50 with atom-level nodes and graph node only.""" + READER = AtomReader_WithGraphNodeOnly diff --git a/chebai_graph/preprocessing/datasets/utils.py 
b/chebai_graph/preprocessing/datasets/utils.py new file mode 100644 index 0000000..3ec6515 --- /dev/null +++ b/chebai_graph/preprocessing/datasets/utils.py @@ -0,0 +1,43 @@ +import importlib + +import chebai_graph.preprocessing.properties as graph_properties +from chebai_graph.preprocessing.properties import MolecularProperty + + +def resolve_property(property: str | MolecularProperty) -> MolecularProperty: + """ + Resolves a molecular property specification (either as a class instance or class path string) + into a MolecularProperty instance. + + This utility is designed to support flexible specification of molecular properties + in dataset configurations. It handles: + - Direct instances of MolecularProperty + - Full class paths as strings (e.g., "my_module.MyProperty") + - Shorthand class names assumed to exist in chebai_graph.preprocessing.properties + + Args: + property (str | MolecularProperty): The property to resolve. Can be a class instance, + a fully qualified class name (e.g. "module.ClassName"), or a class name assumed + to be in `chebai_graph.preprocessing.properties`. + + Returns: + MolecularProperty: An instance of the resolved MolecularProperty. + + Raises: + AttributeError: If the class name cannot be found in the target module. + ModuleNotFoundError: If the module cannot be imported. + TypeError: If the resolved object is not a MolecularProperty instance. 
+ """ + # an existing MolecularProperty instance is returned as-is; strings are resolved as class paths below + if isinstance(property, MolecularProperty): + return property + try: + # split class_path into module-part and class name + last_dot = property.rindex(".") + module_name = property[:last_dot] + class_name = property[last_dot + 1 :] + module = importlib.import_module(module_name) + return getattr(module, class_name)() + except ValueError: + # if only a class name is given, assume the module is chebai_graph.preprocessing.properties + return getattr(graph_properties, property)() From 7524905cf8857373688c383cc3eae4510678833f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 16:02:37 +0200 Subject: [PATCH 156/224] doc for props --- chebai_graph/preprocessing/properties/base.py | 163 ++++++++++--- .../preprocessing/properties/properties.py | 230 ++++++++++++++++-- 2 files changed, 331 insertions(+), 62 deletions(-) diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py index 3614772..501a340 100644 --- a/chebai_graph/preprocessing/properties/base.py +++ b/chebai_graph/preprocessing/properties/base.py @@ -7,102 +7,193 @@ class MolecularProperty(ABC): - def __init__(self, encoder: PropertyEncoder | None = None): + """ + Abstract base class representing a molecular property. + + Properties can be atom-level, bond-level, or molecule-level. + Each property is associated with a PropertyEncoder that encodes + the raw property values into suitable feature representations. + + Args: + encoder: Optional encoder instance to encode property values. + Defaults to IndexEncoder if not provided. 
+ """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: if encoder is None: encoder = IndexEncoder(self) - self.encoder = encoder + self.encoder: PropertyEncoder = encoder @property - def name(self): - """Unique identifier for this property.""" + def name(self) -> str: + """ + Unique identifier for this property, typically the class name. + + Returns: + The class name as the property's unique name. + """ return self.__class__.__name__ - def on_finish(self): - """Called after dataset processing is done.""" + def on_finish(self) -> None: + """ + Called after dataset processing is complete. + + Typically used to finalize encoder states, e.g., saving cache. + """ self.encoder.on_finish() - def __str__(self): + def __str__(self) -> str: + """ + String representation of the property. + + Returns: + The property name. + """ return self.name @abstractmethod - def get_property_value(self, mol: Chem.rdchem.Mol | dict): ... + def get_property_value(self, mol: Chem.rdchem.Mol | dict) -> list: + """ + Abstract method to extract the raw property value(s) from a molecule. + + Args: + mol: RDKit molecule object or a dictionary representation. + + Returns: + A list of raw property values for the molecule. + """ + ... class AtomProperty(MolecularProperty, ABC): - """Property of an atom.""" + """ + Abstract base class representing an atom-level molecular property. - def get_property_value(self, mol: Chem.rdchem.Mol): + Subclasses must implement get_atom_value to extract property per atom. + """ + + def get_property_value(self, mol: Chem.rdchem.Mol) -> list: + """ + Extract the property value for each atom in the molecule. + + Args: + mol: RDKit molecule object. + + Returns: + List of property values, one per atom. 
+ """ return [self.get_atom_value(atom) for atom in mol.GetAtoms()] @abstractmethod - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> object: + """ + Abstract method to extract the property value of a single atom. + + Args: + atom: RDKit atom object. + + Returns: + The property value for the atom. + """ pass class BondProperty(MolecularProperty, ABC): - def get_property_value(self, mol: Chem.rdchem.Mol): + """ + Abstract base class representing a bond-level molecular property. + + Subclasses must implement get_bond_value to extract property per bond. + """ + + def get_property_value(self, mol: Chem.rdchem.Mol) -> list: + """ + Extract the property value for each bond in the molecule. + + Args: + mol: RDKit molecule object. + + Returns: + List of property values, one per bond. + """ return [self.get_bond_value(bond) for bond in mol.GetBonds()] @abstractmethod - def get_bond_value(self, bond: Chem.rdchem.Bond): + def get_bond_value(self, bond: Chem.rdchem.Bond) -> object: + """ + Abstract method to extract the property value of a single bond. + + Args: + bond: RDKit bond object. + + Returns: + The property value for the bond. + """ pass class MoleculeProperty(MolecularProperty): - """Global property of a molecule.""" + """ + Class representing a global (molecule-level) property. + + Subclasses should override get_property_value for molecule-wide values. + """ + + pass class FrozenPropertyAlias(MolecularProperty, ABC): """ - Wrapper base class for augmented graph properties that want to reuse existing molecular properties. - - This class allows augmented graph property classes to inherit both from this wrapper and a standard - molecular property (from `.properties`), enabling reuse of their encoders and index files without - modifying them. + Wrapper base class for augmented graph properties that reuse existing molecular properties. 
- Key Features: - - Prevents new tokens from being added to the encoder cache by freezing it. - - Automatically aligns the property name (used for encoder/index resolution) with the inherited - base property by removing the "Aug" prefix from the class name. + This allows an augmented property class (with an 'Aug' prefix in its name) to: + - Reuse the encoder and index files of the base property by removing the 'Aug' prefix from its name. + - Prevent adding new tokens to the encoder cache by freezing it (using MappingProxyType). Usage: - The derived class should: - - Inherit from `FrozenPropertyAlias` **and** a valid base molecular property class. - - Have a name starting with "Aug" (e.g., `AugAtomType`), which will be resolved to `AtomType`. + Inherit from FrozenPropertyAlias and the desired base molecular property class, + and name the class with an 'Aug' prefix (e.g., 'AugAtomType'). Example: ```python class AugAtomType(FrozenPropertyAlias, AtomType): ... ``` - Note: - Subclass name of this class should with prefix "Aug" for above effect to take place. - This allows `AugAtomType` to reuse the encoder, index files, and logic of `AtomType` while - integrating into augmented graph pipelines. + Raises: + ValueError: If new tokens are added to the frozen encoder during processing. """ - def __init__(self, encoder: PropertyEncoder | None = None): + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder) # Lock the encoder's cache to prevent adding new tokens if hasattr(self.encoder, "cache") and isinstance(self.encoder.cache, dict): self.encoder.cache = MappingProxyType(self.encoder.cache) @property - def name(self): + def name(self) -> str: """ - Unique identifier for this property, with 'Aug' prefix removed if present. - This allows the encoder to reuse index files of the corresponding base property. + Unique identifier for this property. 
+ + Returns: + The class name with the 'Aug' prefix removed if present, + allowing reuse of the base property encoder/index files. """ class_name = self.__class__.__name__ return class_name[3:] if class_name.startswith("Aug") else class_name - def on_finish(self): + def on_finish(self) -> None: + """ + Called after dataset processing. + + Ensures no new tokens were added to the frozen encoder cache. + Raises an error if this condition is violated. + """ if ( hasattr(self.encoder, "cache") and len(self.encoder.cache) > self.encoder.index_length_start ): raise ValueError( - f"{self.__class__.__name__} attempted to add new tokens to a {self.encoder.index_path}" + f"{self.__class__.__name__} attempted to add new tokens " + f"to a frozen encoder at {self.encoder.index_path}" ) super().on_finish() diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py index 8e0a425..ccf869c 100644 --- a/chebai_graph/preprocessing/properties/properties.py +++ b/chebai_graph/preprocessing/properties/properties.py @@ -1,5 +1,3 @@ -from typing import Optional - import numpy as np import rdkit.Chem as Chem from descriptastorus.descriptors import rdNormalizedDescriptors @@ -15,98 +13,278 @@ class AtomType(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Atom property representing the atomic number (type) of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the atomic number of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Atomic number of the atom. 
+ """ return atom.GetAtomicNum() class NumAtomBonds(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Atom property representing the number of bonds (degree) of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the number of bonds for the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Number of bonds (degree). + """ return atom.GetDegree() class AtomCharge(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Atom property representing the formal charge of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the formal charge of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Formal charge. + """ return atom.GetFormalCharge() class AtomChirality(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Atom property representing the chirality tag of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> Chem.rdchem.ChiralType: + """ + Get the chirality tag of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + Chem.rdchem.ChiralType: Chirality tag. 
+ """ return atom.GetChiralTag() class AtomHybridization(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Atom property representing the hybridization state of an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> Chem.rdchem.HybridizationType: + """ + Get the hybridization state of the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + Chem.rdchem.HybridizationType: Hybridization state. + """ return atom.GetHybridization() class AtomNumHs(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Atom property representing the total number of hydrogens bonded to an atom. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> int: + """ + Get the total number of hydrogens attached to the atom. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + int: Number of attached hydrogens. + """ return atom.GetTotalNumHs() class AtomAromaticity(AtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Atom property representing whether an atom is aromatic. + + Uses a boolean encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or BoolEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom): + def get_atom_value(self, atom: Chem.rdchem.Atom) -> bool: + """ + Check if the atom is aromatic. + + Args: + atom (Chem.rdchem.Atom): RDKit atom object. + + Returns: + bool: True if aromatic, else False. 
+ """ return atom.GetIsAromatic() class BondAromaticity(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Bond property representing whether a bond is aromatic. + + Uses a boolean encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or BoolEncoder(self)) - def get_bond_value(self, bond: Chem.rdchem.Bond): + def get_bond_value(self, bond: Chem.rdchem.Bond) -> bool: + """ + Check if the bond is aromatic. + + Args: + bond (Chem.rdchem.Bond): RDKit bond object. + + Returns: + bool: True if aromatic, else False. + """ return bond.GetIsAromatic() class BondType(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Bond property representing the bond type (single, double, etc.). + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_bond_value(self, bond: Chem.rdchem.Bond): + def get_bond_value(self, bond: Chem.rdchem.Bond) -> Chem.rdchem.BondType: + """ + Get the bond type. + + Args: + bond (Chem.rdchem.Bond): RDKit bond object. + + Returns: + Chem.rdchem.BondType: Type of bond. + """ return bond.GetBondType() class BondInRing(BondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Bond property indicating whether a bond is in a ring. + + Uses a boolean encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or BoolEncoder(self)) - def get_bond_value(self, bond: Chem.rdchem.Bond): + def get_bond_value(self, bond: Chem.rdchem.Bond) -> bool: + """ + Check if the bond is part of a ring. + + Args: + bond (Chem.rdchem.Bond): RDKit bond object. + + Returns: + bool: True if in a ring, else False. 
+ """ return bond.IsInRing() class MoleculeNumRings(MolecularProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Molecule-level property representing the number of rings in the molecule. + + Uses a one-hot encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or OneHotEncoder(self)) - def get_property_value(self, mol: Chem.rdchem.Mol): + def get_property_value(self, mol: Chem.rdchem.Mol) -> list[int]: + """ + Get the number of rings in the molecule. + + Args: + mol (Chem.rdchem.Mol): RDKit molecule object. + + Returns: + list[int]: List with single integer representing number of rings. + """ return [mol.GetRingInfo().NumRings()] class RDKit2DNormalized(MolecularProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + """ + Molecule-level property representing normalized 2D descriptors from RDKit. + + Uses an identity encoder by default. + """ + + def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or AsIsEncoder(self)) - def get_property_value(self, mol: Chem.rdchem.Mol): + def get_property_value(self, mol: Chem.rdchem.Mol) -> list[np.ndarray]: + """ + Compute normalized RDKit 2D descriptors for the molecule. + + Args: + mol (Chem.rdchem.Mol): RDKit molecule object. + + Returns: + list[np.ndarray]: List containing the descriptor numpy array (excluding first element). 
+ """ generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() features_normalized = generator_normalized.processMol( mol, Chem.MolToSmiles(mol) From d11305ff7c69485b335ff1847dcbc9e5991becfd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 19:59:19 +0200 Subject: [PATCH 157/224] doc for aug props --- .../properties/augmented_properties.py | 416 +++++++++++++++--- 1 file changed, 353 insertions(+), 63 deletions(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 8739c6f..a0ce122 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -1,6 +1,5 @@ import sys from abc import ABC -from typing import Dict, List, Optional from rdkit import Chem @@ -28,7 +27,21 @@ class AugmentedAtomProperty(AtomProperty, ABC): MAIN_KEY = "nodes" - def get_property_value(self, augmented_mol: Dict): + def get_property_value(self, augmented_mol: dict) -> list: + """ + Extract property values for atoms from the augmented molecule dictionary. + + Args: + augmented_mol (dict): Dictionary representing the augmented molecule. + + Raises: + KeyError: If required keys are missing in the dictionary. + TypeError: If types of contained objects are incorrect. + AssertionError: If the number of property values does not match number of nodes. + + Returns: + list: List of property values for all atoms, functional groups, and graph nodes. 
+ """ if self.MAIN_KEY not in augmented_mol: raise KeyError( f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" @@ -68,14 +81,44 @@ def get_property_value(self, augmented_mol: Dict): ), "Number of property values should be equal to number of nodes" return prop_list - def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + def _check_modify_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> str | int | bool: + """ + Check that the property value for the atom/node exists and is not empty. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node representation. + prop (str): Property name. + + Raises: + ValueError: If the property is empty. + + Returns: + str | int | bool: The property value. + """ value = self._get_atom_prop_value(atom, prop) if not value: # Every atom/node should have given value raise ValueError(f"'{prop}' is set but empty.") return value - def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + def _get_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> str | int | bool: + """ + Retrieve a property value from an atom or dict node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + prop (str): Property name. + + Raises: + TypeError: If atom is not an expected type. + + Returns: + str | int | bool: The property value. + """ if isinstance(atom, Chem.rdchem.Atom): return atom.GetProp(prop) elif isinstance(atom, dict): @@ -87,29 +130,86 @@ def _get_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): class AtomNodeLevel(AugmentedAtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize AtomNodeLevel with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to OneHotEncoder. 
+ """ super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> str | int | bool: + """ + Get the node level property for a given atom/node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + str | int | bool: Property value. + """ return self._check_modify_atom_prop_value(atom, k.NODE_LEVEL) class AtomFunctionalGroup(AugmentedAtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize AtomFunctionalGroup with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to OneHotEncoder. + """ super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> str | int | bool: + """ + Get the functional group property for a given atom/node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + str | int | bool: Property value. + """ return self._check_modify_atom_prop_value(atom, "FG") class AtomRingSize(AugmentedAtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize AtomRingSize with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to OneHotEncoder. + """ super().__init__(encoder or OneHotEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int: + """ + Get the ring size for a given atom/node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + int: Maximum ring size the atom belongs to, or 0 if none. 
+ """ return self._check_modify_atom_prop_value(atom, "RING") - def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str): + def _check_modify_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> int: + """ + Override to parse and return maximum ring size from a property string. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + prop (str): Property name. + + Returns: + int: Maximum ring size or 0. + """ ring_size_str = self._get_atom_prop_value(atom, prop) if ring_size_str: ring_sizes = list(map(int, ring_size_str.split("-"))) @@ -120,7 +220,13 @@ def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str class IsHydrogenBondDonorFG(AugmentedAtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize IsHydrogenBondDonorFG with an optional encoder. + + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to BoolEncoder. + """ super().__init__(encoder or BoolEncoder(self)) # fmt: off # https://github.com/thaonguyen217/farm_molecular_representation/blob/main/src/(6)gen_FG_KG.py#L26-L31 @@ -132,13 +238,28 @@ def __init__(self, encoder: Optional[PropertyEncoder] = None): } # fmt: on - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> bool: + """ + Check if the atom's functional group is a hydrogen bond donor. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + + Returns: + bool: True if hydrogen bond donor, else False. + """ fg = self._check_modify_atom_prop_value(atom, "FG") return fg in self._hydrogen_bond_donor class IsHydrogenBondAcceptorFG(AugmentedAtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Initialize IsHydrogenBondAcceptorFG with an optional encoder. 
+ + Args: + encoder (PropertyEncoder | None): Property encoder to use. Defaults to BoolEncoder. + """ super().__init__(encoder or BoolEncoder(self)) # fmt: off # https://github.com/thaonguyen217/farm_molecular_representation/blob/main/src/(6)gen_FG_KG.py#L33-L39 @@ -151,21 +272,56 @@ def __init__(self, encoder: Optional[PropertyEncoder] = None): } # fmt: on - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> bool: + """ + Determine if the atom is a hydrogen bond acceptor. + + Args: + atom (Chem.rdchem.Atom | dict): The atom object or a dictionary of atom properties. + + Returns: + bool: True if the atom is a hydrogen bond acceptor, False otherwise. + """ fg = self._check_modify_atom_prop_value(atom, "FG") return fg in self._hydrogen_bond_acceptor class IsFGAlkyl(AugmentedAtomProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Args: + encoder (PropertyEncoder | None): Optional encoder to use for this property. + Defaults to BoolEncoder if not provided. + """ super().__init__(encoder or BoolEncoder(self)) - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int: + """ + Get the alkyl group status of the given atom. + + Args: + atom (Chem.rdchem.Atom | dict): Atom object or atom property dictionary. + + Returns: + int: 1 if alkyl, 0 otherwise. + """ return int(self._check_modify_atom_prop_value(atom, "is_alkyl")) class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): - def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): + def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int | None: + """ + Get the property value for an atom or dict node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom object or dict representing node properties. + + Returns: + int | None: Property value or None for dict nodes. 
+ + Raises: + TypeError: If input is neither Chem.rdchem.Atom nor dict. + """ if isinstance(atom, Chem.rdchem.Atom): # Delegate to superclass method for atom return super().get_atom_value(atom) @@ -178,54 +334,95 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | Dict): class AugAtomType(AugNodeValueDefaulter, pr.AtomType): - # This property uses OneHotEncoder as default encoder - # TODO: Can we return 0 for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes - # Currently, we return None which leads to zero-tensor for augmented nodes - - # RDKit uses 0 as the atomic number for a "dummy atom", which usually means: - # A placeholder atom (e.g. [*], R#, or attachment points in SMARTS/SMILES). - # An undefined or wildcard atom. - # A pseudoatom (e.g., for certain fragments or placeholders in reaction centers). + """ + This property uses OneHotEncoder as default encoder + + TODO: Can we return 0 for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes + Currently, we return None which leads to zero-tensor for augmented nodes + + RDKit uses 0 as the atomic number for a "dummy atom", which usually means: + - A placeholder atom (e.g. [*], R#, or attachment points in SMARTS/SMILES). + - An undefined or wildcard atom. + - A pseudoatom (e.g., for certain fragments or placeholders in reaction centers). + """ + ... class AugNumAtomBonds(AugNodeValueDefaulter, pr.NumAtomBonds): - # This property uses OneHotEncoder as default encoder - # Default return value for this property can't be zero, 0 is used for isolated atoms in molecule. It has to be None or actual node degree. - # TODO: Can return actual node degree/num of connections for augmented Nodes for this property? 
which will lead to use of one hot tensor for augmented nodes - # Currently, we return None which leads to zero-tensor for augmented nodes - # But then the question aries shall we count only the atoms connected to a fg node, or all nodes including atoms. Consider graph node too. + """ + This property uses OneHotEncoder as default encoder + + Default return value for this property can't be zero, 0 is used for isolated atoms in molecule. + It has to be None or actual node degree. + + TODO: Can return actual node degree/num of connections for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Currently, we return None which leads to zero-tensor for augmented nodes + + But then the question arises: shall we count only the atoms connected to a fg node, or all nodes including atoms? + Consider graph node too. + """ + ... class AugAtomCharge(AugNodeValueDefaulter, pr.AtomCharge): - # This property uses OneHotEncoder as default encoder - # Default return value for this property can't be zero, as atoms can have 0 charge. - # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes - # Currently, we return None which leads to zero-tensor for augmented nodes + """ + This property uses OneHotEncoder as default encoder + + Default return value for this property can't be zero, as atoms can have 0 charge. + + TODO: Can return some `unk` value for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + ... class AugAtomHybridization(AugNodeValueDefaulter, pr.AtomHybridization): - # This property uses OneHotEncoder as default encoder - # TODO: Can return some `HybridizationType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? 
which will lead to use of one hot tensor for augmented nodes - # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.HybridizationType - # Currently, we return None which leads to zero-tensor for augmented nodes + """ + This property uses OneHotEncoder as default encoder + + TODO: Can return some `HybridizationType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.HybridizationType + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + ... class AugAtomNumHs(AugNodeValueDefaulter, pr.AtomNumHs): - # This property uses OneHotEncoder as default encoder - # Default return value for this property can't be zero, as atoms can have 0 Hydrogen atoms attached which mean atoms is full balanced by bonding with other non-hydrogen atoms. - # TODO: Can return some `unk` value for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes - # Currently, we return None which leads to zero-tensor for augmented nodes + """ + This property uses OneHotEncoder as default encoder + + Default return value for this property can't be zero, as atoms can have 0 Hydrogen atoms attached + which means the atom is fully balanced by bonding with other non-hydrogen atoms. + + TODO: Can return some `unk` value for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + ... class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): - # This property uses BoolEncoder as default encoder - # Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. - # This is None is right value for augmented nodes its not part of any kind of aromatic ring. 
+ """ + This property uses BoolEncoder as default encoder + + Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + + This is None is right value for augmented nodes its not part of any kind of aromatic ring. + """ + ... @@ -233,7 +430,21 @@ class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): class AugmentedBondProperty(BondProperty, ABC): MAIN_KEY = "edges" - def get_property_value(self, augmented_mol: Dict) -> List: + def get_property_value(self, augmented_mol: dict) -> list: + """ + Get bond property values from augmented molecule dict. + + Args: + augmented_mol (dict): Augmented molecule dictionary containing edges. + + Returns: + list: List of property values for bonds in the augmented molecule. + + Raises: + KeyError: If required keys are missing in augmented_mol. + TypeError: If the expected objects are not of correct types. + AssertionError: If number of property values does not match expected edge count. + """ if self.MAIN_KEY not in augmented_mol: raise KeyError( f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" @@ -288,7 +499,22 @@ def get_property_value(self, augmented_mol: Dict) -> List: return prop_list - def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str): + def _check_modify_bond_prop_value( + self, bond: Chem.rdchem.Bond | dict, prop: str + ) -> str: + """ + Helper to check and get bond property value. + + Args: + bond (Chem.rdchem.Bond | dict): Bond object or bond property dict. + prop (str): Property key to get. + + Returns: + str: Property value. + + Raises: + ValueError: If value is empty or falsy. 
+ """ value = self._get_bond_prop_value(bond, prop) if not value: # Every atom/node should have given value @@ -296,7 +522,20 @@ def _check_modify_bond_prop_value(self, bond: Chem.rdchem.Bond | Dict, prop: str return value @staticmethod - def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): + def _get_bond_prop_value(bond: Chem.rdchem.Bond | dict, prop: str) -> str: + """ + Extract bond property value from bond or dict. + + Args: + bond (Chem.rdchem.Bond | dict): Bond object or dict. + prop (str): Property key. + + Returns: + str: Property value. + + Raises: + TypeError: If bond is not the expected type. + """ if isinstance(bond, Chem.rdchem.Bond): return bond.GetProp(prop) elif isinstance(bond, dict): @@ -306,15 +545,40 @@ def _get_bond_prop_value(bond: Chem.rdchem.Bond | Dict, prop: str): class BondLevel(AugmentedBondProperty): - def __init__(self, encoder: Optional[PropertyEncoder] = None): + def __init__(self, encoder: PropertyEncoder | None = None): + """ + Args: + encoder (PropertyEncoder | None): Optional encoder to use. Defaults to OneHotEncoder. + """ super().__init__(encoder or OneHotEncoder(self)) - def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): + def get_bond_value(self, bond: Chem.rdchem.Bond | dict) -> str: + """ + Get the bond level property value. + + Args: + bond (Chem.rdchem.Bond | dict): Bond or bond dict. + + Returns: + str: Bond level property. + """ return self._check_modify_bond_prop_value(bond, k.EDGE_LEVEL) class AugBondValueDefaulter(AugmentedBondProperty, FrozenPropertyAlias, ABC): - def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): + def get_bond_value(self, bond: Chem.rdchem.Bond | dict) -> str | None: + """ + Get bond property value or None for dict bonds. + + Args: + bond (Chem.rdchem.Bond | dict): Bond or bond dict. + + Returns: + str | None: Property value or None for dict. + + Raises: + TypeError: If input type is invalid. 
+ """ if isinstance(bond, Chem.rdchem.Bond): # Delegate to superclass method for bond return super().get_bond_value(bond) @@ -325,30 +589,56 @@ def get_bond_value(self, bond: Chem.rdchem.Bond | Dict): class AugBondAromaticity(AugBondValueDefaulter, pr.BondAromaticity): - # This property uses BoolEncoder as default encoder - # Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. - # This is None is right value for augmented nodes its not part of any kind of aromatic ring. + """ + This property uses BoolEncoder as default encoder + + Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + + This is None is right value for augmented nodes its not part of any kind of aromatic ring. + """ + ... class AugBondType(AugBondValueDefaulter, pr.BondType): - # This property uses OneHotEncoder as default encoder - # TODO: Can return some `BondType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? which will lead to use of one hot tensor for augmented nodes - # Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondType - # Currently, we return None which leads to zero-tensor for augmented nodes + """ + This property uses OneHotEncoder as default encoder + + TODO: Can return some `BondType.UNSPECIFIED` value which is 0 for augmented Nodes for this property? + which will lead to use of one hot tensor for augmented nodes + + Check: https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondType + + Currently, we return None which leads to zero-tensor for augmented nodes + """ + ... class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): - # This property uses BoolEncoder as default encoder - # Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. - # This is None is right value for augmented nodes its not part of any kind of aromatic ring. 
+ """ + This property uses BoolEncoder as default encoder + + Currently, we return None for augmented nodes which leads to BoolEncoder setting 0 internally. + + This is None is right value for augmented nodes its not part of any kind of aromatic ring. + """ + ... # --------------------- Molecular Properties ------------------------------ class AugmentedMolecularProperty(pr.MolecularProperty, ABC): - def get_property_value(self, augmented_mol: Dict) -> list: + def get_property_value(self, augmented_mol: dict) -> list: + """ + Get molecular property values from augmented molecule dict. + + Args: + augmented_mol (dict): Augmented molecule dict. + + Returns: + list: Property values of molecule. + """ mol: Chem.Mol = augmented_mol[AugmentedAtomProperty.MAIN_KEY]["atom_nodes"] assert isinstance(mol, Chem.Mol), "Molecule should be instance of `Chem.Mol`" return super().get_property_value(mol) From ab25622ac02f5ff710d5320a4e47b8384ae5874b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 20:02:51 +0200 Subject: [PATCH 158/224] doc for reader --- chebai_graph/preprocessing/reader/reader.py | 110 ++++++++++++++++---- 1 file changed, 91 insertions(+), 19 deletions(-) diff --git a/chebai_graph/preprocessing/reader/reader.py b/chebai_graph/preprocessing/reader/reader.py index 15a131d..a63b8a1 100644 --- a/chebai_graph/preprocessing/reader/reader.py +++ b/chebai_graph/preprocessing/reader/reader.py @@ -1,12 +1,10 @@ import os -from typing import List, Optional import chebai.preprocessing.reader as dr import networkx as nx import pysmiles as ps import rdkit.Chem as Chem import torch -from lightning_utilities.core.rank_zero import rank_zero_info, rank_zero_warn from torch_geometric.data import Data as GeomData from torch_geometric.utils import from_networkx @@ -21,34 +19,64 @@ def __init__( self, *args, **kwargs, - ): + ) -> None: + """ + Initialize GraphPropertyReader. + + Args: + *args: Positional arguments forwarded to the base class. 
+ **kwargs: Keyword arguments forwarded to the base class. + """ super().__init__(*args, **kwargs) self.failed_counter = 0 - self.mol_object_buffer = {} + self.mol_object_buffer: dict[str, Chem.rdchem.Mol | None] = {} @classmethod - def name(cls): + def name(cls) -> str: + """ + Get the name identifier of the reader. + + Returns: + str: The name of the reader. + """ return "graph_properties" - def _smiles_to_mol(self, smiles: str) -> Optional[Chem.rdchem.Mol]: - """Load smiles into rdkit, store object in buffer""" + def _smiles_to_mol(self, smiles: str) -> Chem.rdchem.Mol | None: + """ + Load SMILES string into an RDKit molecule object and cache it. + + Args: + smiles (str): The SMILES string to parse. + + Returns: + Chem.rdchem.Mol | None: Parsed molecule object or None if parsing failed. + """ if smiles in self.mol_object_buffer: return self.mol_object_buffer[smiles] mol = Chem.MolFromSmiles(smiles) if mol is None: - rank_zero_warn(f"RDKit failed to at parsing {smiles} (returned None)") + print(f"RDKit failed to at parsing {smiles} (returned None)") self.failed_counter += 1 else: try: Chem.SanitizeMol(mol) except Exception as e: - rank_zero_warn(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") + print(f"Rdkit failed at sanitizing {smiles}, \n Error: {e}") self.failed_counter += 1 self.mol_object_buffer[smiles] = mol return mol - def _read_data(self, raw_data): + def _read_data(self, raw_data: str) -> GeomData | None: + """ + Convert raw SMILES string data into a PyTorch Geometric Data object. + + Args: + raw_data (str): SMILES string. + + Returns: + GeomData | None: Graph data object or None if molecule parsing failed. 
+ """ mol = self._smiles_to_mol(raw_data) if mol is None: return None @@ -65,11 +93,24 @@ def _read_data(self, raw_data): return GeomData(x=x, edge_index=edge_index, edge_attr=edge_attr) - def on_finish(self): - rank_zero_info(f"Failed to read {self.failed_counter} SMILES in total") + def on_finish(self) -> None: + """ + Called after reading is done to log information and clean up. + """ + print(f"Failed to read {self.failed_counter} SMILES in total") self.mol_object_buffer = {} - def read_property(self, smiles: str, property: MolecularProperty) -> Optional[List]: + def read_property(self, smiles: str, property: MolecularProperty) -> list | None: + """ + Read a molecular property for a given SMILES string. + + Args: + smiles (str): SMILES string of the molecule. + property (MolecularProperty): Property extractor to apply. + + Returns: + list | None: Property values or None if molecule parsing failed. + """ mol = self._smiles_to_mol(smiles) if mol is None: return None @@ -82,23 +123,45 @@ class GraphReader(dr.ChemDataReader): COLLATOR = GraphCollator - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: + """ + Initialize GraphReader. + + Args: + *args: Positional arguments forwarded to the base class. + **kwargs: Keyword arguments forwarded to the base class. + """ super().__init__(*args, **kwargs) self.dirname = os.path.dirname(__file__) @classmethod - def name(cls): + def name(cls) -> str: + """ + Get the name identifier of the reader. + + Returns: + str: The name of the reader. + """ return "graph" - def _read_data(self, raw_data) -> Optional[GeomData]: + def _read_data(self, raw_data: str) -> GeomData | None: + """ + Convert a SMILES string into a PyTorch Geometric Data object with atom tokens and bond order attributes. + + Args: + raw_data (str): SMILES string. + + Returns: + GeomData | None: Graph data object or None if parsing failed. 
+ """ # raw_data is a SMILES string try: mol = ps.read_smiles(raw_data) except ValueError: return None assert isinstance(mol, nx.Graph) - d = {} - de = {} + d: dict[int, int] = {} + de: dict[tuple[int, int], int] = {} for node in mol.nodes: n = mol.nodes[node] try: @@ -127,5 +190,14 @@ def _read_data(self, raw_data) -> Optional[GeomData]: data = from_networkx(mol) return data - def collate(self, list_of_tuples): + def collate(self, list_of_tuples: list) -> any: + """ + Collate a list of samples into a batch. + + Args: + list_of_tuples (list): List of data tuples to collate. + + Returns: + Any: Collated batch (type depends on collator). + """ return self.collator(list_of_tuples) From c1faf989263a7ee7fe7be28a6f808766f8a7ddca Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 20:32:11 +0200 Subject: [PATCH 159/224] doc for aug reader --- .../preprocessing/reader/augmented_reader.py | 201 +++++++++++++----- 1 file changed, 152 insertions(+), 49 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index c11df8f..59f15bd 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -1,7 +1,6 @@ import re import sys from abc import ABC -from typing import Any, Optional import torch from chebai.preprocessing.reader import DataReader @@ -32,7 +31,7 @@ class _AugmentorReader(DataReader, ABC): COLLATOR = GraphCollator - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """ Initializes the augmentor reader and sets up the failure counter and molecule cache. 
@@ -42,12 +41,12 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) # Record number of failures when constructing molecule from smiles - self.f_cnt_for_smiles = 0 + self.f_cnt_for_smiles: int = 0 # Record number of failure during augmented graph construction - self.f_cnt_for_aug_graph = 0 - self.mol_object_buffer = {} - self._idx_of_node = 0 - self._idx_of_edge = 0 + self.f_cnt_for_aug_graph: int = 0 + self.mol_object_buffer: dict[str, dict] = {} + self._idx_of_node: int = 0 + self._idx_of_edge: int = 0 @classmethod def name(cls) -> str: @@ -67,7 +66,11 @@ def _read_data(self, smiles: str) -> GeomData | None: smiles (str): SMILES representation of the molecule. Returns: - GeomData: A PyTorch Geometric Data object with augmented nodes and edges. + GeomData | None: A PyTorch Geometric Data object with augmented nodes and edges, + or None if parsing or augmentation fails. + + Raises: + RuntimeError: If an unexpected error occurs during graph augmentation. """ mol = self._smiles_to_mol(smiles) if mol is None: @@ -122,7 +125,7 @@ def _read_data(self, smiles: str) -> GeomData | None: is_atom_node=is_atom_mask, ) - def _smiles_to_mol(self, smiles: str) -> Chem.Mol: + def _smiles_to_mol(self, smiles: str) -> Chem.Mol | None: """ Converts a SMILES string to an RDKit molecule object. Sanitizes the molecule. @@ -130,7 +133,7 @@ def _smiles_to_mol(self, smiles: str) -> Chem.Mol: smiles (str): SMILES string representing the molecule. Returns: - Chem.Mol: RDKit molecule object. + Chem.Mol | None: RDKit molecule object if successful, else None. 
""" mol = Chem.MolFromSmiles(smiles) if mol is None: @@ -142,19 +145,26 @@ def _smiles_to_mol(self, smiles: str) -> Chem.Mol: except Exception as e: print(f"RDKit failed at sanitizing {smiles}, Error {e}") self.f_cnt_for_smiles += 1 + mol = None return mol - def _create_augmented_graph(self, mol: Chem.Mol) -> tuple[torch.Tensor, dict]: + def _create_augmented_graph( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict] | None: """ - Generates an augmented graph from a SMILES string. + Generates an augmented graph from a molecule. Args: mol (Chem.Mol): A molecule generated by RDKit. Returns: - Tuple[torch.Tensor, dict]: - - Augmented graph edge index, - - Augmented graph (nodes and edges). + tuple[torch.Tensor, dict] | None: + - Augmented graph edge index tensor, + - Augmented graph data dictionary (nodes and edges), + or None if augmentation fails. + + Raises: + ValueError: If directed_edge_index shape is incorrect. """ augmented_mol = self._augment_graph_structure(mol) @@ -191,9 +201,10 @@ def _augment_graph_structure(self, mol: Chem.Mol) -> dict: Returns: dict: A dictionary containing: - - Augmented graph edge index, - - Augmented graph node attributes - - Augmented graph edge attributes. + - "directed_edge_index" (torch.Tensor): Directed edge index tensor, + - "node_info" (dict): Node attributes dictionary, + - "edge_info" (dict): Edge attributes dictionary, + - "graph_meta_info" (dict): Additional meta information. """ self._idx_of_node = mol.GetNumAtoms() self._idx_of_edge = mol.GetNumBonds() @@ -249,10 +260,10 @@ def _generate_atom_level_edge_index(mol: Chem.Mol) -> torch.Tensor: mol (Chem.Mol): RDKit molecule. Returns: - torch.Tensor: Directed edge index tensor. + torch.Tensor: Directed edge index tensor with shape [2, num_edges]. 
""" # We need to ensure that directed edges which form a undirected edge are adjacent to each other - edge_index_list = [[], []] + edge_index_list: list[list[int]] = [[], []] for bond in mol.GetBonds(): edge_index_list[0].append(bond.GetBeginAtomIdx()) edge_index_list[1].append(bond.GetEndAtomIdx()) @@ -268,7 +279,7 @@ def on_finish(self) -> None: ) self.mol_object_buffer = {} - def read_property(self, smiles: str, property: MolecularProperty) -> Optional[list]: + def read_property(self, smiles: str, property: MolecularProperty) -> list | None: """ Reads a specific property from a molecule represented by a SMILES string. @@ -277,7 +288,7 @@ def read_property(self, smiles: str, property: MolecularProperty) -> Optional[li property (MolecularProperty): Molecular property object for which the value needs to be extracted. Returns: - Optional[List]: Property values if molecule parsing is successful, else None. + list | None: Property values if molecule parsing is successful, else None. """ if smiles in self.mol_object_buffer: return property.get_property_value(self.mol_object_buffer[smiles]) @@ -295,6 +306,8 @@ def read_property(self, smiles: str, property: MolecularProperty) -> Optional[li class AtomsFGReader_NoFGEdges_NoGraphNode(_AugmentorReader): + """Adds FG nodes without intra-functional group edges and without introducing a graph-level node.""" + def _augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: @@ -305,10 +318,10 @@ def _augment_graph_structure( mol (Chem.Mol): RDKit molecule object. Returns: - Tuple[torch.Tensor, dict, dict]: - - Augmented graph edge index, - - Augmented graph node attributes - - Augmented graph edge attributes. + Tuple[Tensor, dict, dict]: A tuple containing: + - Augmented graph edge index (Tensor), + - Augmented graph node attributes (dict), + - Augmented graph edge attributes (dict). 
""" augmented_mol = super()._augment_graph_structure(mol) atom_edge_index = augmented_mol["directed_edge_index"] @@ -359,17 +372,16 @@ def _construct_fg_to_atom_structure( mol (Chem.Mol): RDKit molecule. Returns: - tuple[list[list[int]], dict, dict, dict, list]: A tuple containing: + Tuple[list[list[int]], dict, dict, dict, list]: A tuple containing: - Edge index for FG to atom connections. - - FG node info, - - FG-atom edge attributes, - - FG to atoms mapping, + - FG node info. + - FG-atom edge attributes. + - FG to atoms mapping. - Bonds between FG nodes. Raises: ValueError: If functional groups span multiple ring sizes or if no functional group is assigned to atoms. """ - # Rule-based algorithm to detect functional groups structure, bonds = get_structure(mol) assert structure is not None, "Failed to detect functional groups." @@ -413,7 +425,17 @@ def _construct_fg_to_atom_structure( return fg_atom_edge_index, fg_nodes, atom_fg_edges, fg_to_atoms_map, bonds - def _set_ring_fg_prop(self, connected_atoms, fg_nodes): + def _set_ring_fg_prop(self, connected_atoms: list, fg_nodes: dict) -> None: + """ + Sets ring functional group properties. + + Args: + connected_atoms (list): List of atoms in the ring. + fg_nodes (dict): Dictionary to store FG node attributes. + + Raises: + ValueError: If an atom in the ring does not have a ring size set. + """ # FG atoms have ring size, which indicates the FG is a Ring or Fused Rings ring_size = len(connected_atoms) fg_nodes[self._idx_of_node] = { @@ -433,7 +455,21 @@ def _set_ring_fg_prop(self, connected_atoms, fg_nodes): atom.SetProp("FG", f"RING_{max_ring_size}") atom.SetProp("is_alkyl", "0") - def _set_fg_prop(self, connected_atoms, fg_nodes, fg_smiles): + def _set_fg_prop( + self, connected_atoms: list, fg_nodes: dict, fg_smiles: str + ) -> None: + """ + Sets non-ring functional group properties. + + Args: + connected_atoms (list): Atoms in the FG. + fg_nodes (dict): Dictionary to store FG node attributes. 
+ fg_smiles (str): SMILES of the FG. + + Raises: + ValueError: If functional group assignment is inconsistent or missing. + AssertionError: If no representative atom is found. + """ fg_set = {atom.GetProp("FG") for atom in connected_atoms} if not fg_set: raise ValueError( @@ -482,9 +518,20 @@ def _set_fg_prop(self, connected_atoms, fg_nodes, fg_smiles): class AtomFGReader_WithFGEdges_NoGraphNode(AtomsFGReader_NoFGEdges_NoGraphNode): + """Adds FG nodes with intra-functional group edges and without introducing a graph-level node.""" + def _augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the molecule graph with intra-functional group edges. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph with FG-level edges. + """ augmented_struct = super()._augment_graph_structure(mol) graph_meta_info = augmented_struct["graph_meta_info"] @@ -525,23 +572,23 @@ def _construct_fg_level_structure( bonds (list): List of bond tuples (source, target, ...). Returns: - Tuple[List[List[int]], dict]: - - Edge index within fg nodes - - Edge attributes for edges within fg nodes. + tuple[list[list[int]], dict]: + - Edge index within FG nodes. + - Edge attributes for edges within FG nodes. """ internal_fg_edges = {} internal_edge_index = [[], []] - def add_fg_internal_edge(source_fg, target_fg): + def add_fg_internal_edge(source_fg: int, target_fg: int) -> None: assert ( source_fg is not None and target_fg is not None ), "Each bond should have a fg node on both end" - assert source_fg != target_fg, "Source and Target FG should be different" + assert source_fg != target_fg, "Source and Target FG should be different" edge_key = tuple(sorted((source_fg, target_fg))) edge_str = f"{edge_key[0]}_{edge_key[1]}" if edge_str not in internal_fg_edges: - # If two atoms of a FG points to atom(s) belonging to another FG. In this case, only one edge is counted. 
+ # If two atoms of a FG point to atom(s) belonging to another FG, only one edge is counted. # Eg. In CHEBI:52723, atom idx 13 and 16 of a FG points to atom idx 18 of another FG internal_edge_index[0].append(source_fg) internal_edge_index[1].append(target_fg) @@ -579,7 +626,18 @@ def add_fg_internal_edge(source_fg, target_fg): class _AddGraphNode(_AugmentorReader): - def _read_data(self, smiles): + """Adds a graph-level node and connects it to selected/given nodes.""" + + def _read_data(self, smiles: str) -> GeomData | None: + """ + Reads data and adds a graph-level node annotation. + + Args: + smiles (str): SMILES string. + + Returns: + Data | None: Geometric data object with is_graph_node annotation. + """ geom_data = super()._read_data(smiles) if geom_data is None: return None @@ -592,8 +650,18 @@ def _read_data(self, smiles): def _add_graph_node_and_edges_to_nodes( self, augmented_struct: dict, - nodes_ids: dict[int, Any] | set[int], + nodes_ids: dict[int, object] | set[int], ) -> tuple[torch.Tensor, dict, dict]: + """ + Adds a graph-level node and connects it to given nodes. + + Args: + augmented_struct (dict): Current graph structure. + nodes_ids (dict[int, object] | set[int]): Node indices to connect to the graph-level node. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure with graph node edges and metadata. + """ nodes_graph_edge_index, graph_node, nodes_to_graph_edges = ( self._construct_nodes_to_graph_node_structure(nodes_ids) ) @@ -620,19 +688,19 @@ def _add_graph_node_and_edges_to_nodes( return augmented_struct def _construct_nodes_to_graph_node_structure( - self, nodes_ids: dict + self, nodes_ids: dict[int, object] | set[int] ) -> tuple[list[list[int]], dict, dict]: """ - Constructs edges between functional group nodes and a global graph-level node. + Constructs edges between selected nodes and a global graph-level node. Args: - fg_to_atoms_map (dict): Mapping from FG ID to atom indices. 
+ nodes_ids (dict[int, object] | set[int]): IDs of nodes to connect to the graph-level node. Returns: - Tuple[List[List[int]], dict, dict]: - - Graph to FG Edge index - - Graph-level node attribute - - FG to Graph Edge attributes + tuple[list[list[int]], dict, dict]: + - Edge index connecting nodes to graph node. + - Graph-level node attributes. + - Edge attributes for graph-level connections. """ graph_node = {k.NODE_LEVEL: k.GRAPH_NODE_LEVEL, "FG": "graph_fg", "RING": "0"} @@ -654,9 +722,20 @@ def _construct_nodes_to_graph_node_structure( class AtomFGReader_WithFGEdges_WithGraphNode( AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode ): + """Adds FG nodes with intra-functional group edges and a graph-level node.""" + def _augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with FG edges and a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ augmented_struct = super()._augment_graph_structure(mol) fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] return self._add_graph_node_and_edges_to_nodes( @@ -667,9 +746,20 @@ def _augment_graph_structure( class AtomFGReader_NoFGEdges_WithGraphNode( AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode ): + """Adds FG nodes without functional group edges and a graph-level node.""" + def _augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with only a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. 
+ """ augmented_struct = super()._augment_graph_structure(mol) fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] return self._add_graph_node_and_edges_to_nodes( @@ -678,7 +768,20 @@ def _augment_graph_structure( class AtomReader_WithGraphNodeOnly(_AddGraphNode): - def _augment_graph_structure(self, mol): + """Adds a graph-level node and connects it to all atom nodes.""" + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph by adding a graph-level node connected to all atoms. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ augmented_struct = super()._augment_graph_structure(mol) molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} From 68f9f7af31a61e68ed55cf452569ce1b8b659be8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 20:33:28 +0200 Subject: [PATCH 160/224] doc for vis --- .../preprocessing/utils/visualize_augmented_molecule.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index a75440d..897422d 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -477,6 +477,12 @@ def plot( - simple : 2D graph with all nodes on same plane - h: Hierarchical 2D-graph with separate plane for each node type - 3d: Hierarchical 3D-graph + reader (str): Reader type for graph augmentation. Options: + - 'n_fge_w_gn': FG nodes without FG edges, with a graph node. + - 'w_fge_w_gn': FG nodes with FG edges, with a graph node. + - 'w_fge_n_gn': FG nodes with FG edges, no graph node. + - 'n_fge_n_gn': FG nodes without FG edges, no graph node. 
+ - 'atom_w_gn': Atom nodes only, connected to a graph node. """ fg_reader = READER[reader]() mol = fg_reader._smiles_to_mol(smiles) From 9914b0d7ccafca50b45e9fea497fad44e3cc6092 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 20:46:07 +0200 Subject: [PATCH 161/224] doc for graph collator --- chebai_graph/preprocessing/collate.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/chebai_graph/preprocessing/collate.py b/chebai_graph/preprocessing/collate.py index 4be36cf..53bdde5 100644 --- a/chebai_graph/preprocessing/collate.py +++ b/chebai_graph/preprocessing/collate.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from chebai.preprocessing.collate import RaggedCollator from torch_geometric.data import Data as GeomData @@ -9,19 +7,27 @@ class GraphCollator(RaggedCollator): + """Collates a batch of molecular graph data with label handling and edge consistency.""" + def __call__(self, data): - loss_kwargs: Dict = dict() + loss_kwargs: dict = {} + # Unpack labels and optional identifiers y, idents = zip(*((d["labels"], d.get("ident")) for d in data)) + + # Replace labels with `y` inside graph features and collect them merged_data = [] for row in data: row["features"].y = row["labels"] merged_data.append(row["features"]) - # add empty edge_attr to avoid problems during collate (only relevant for molecules without edges) + + # Add empty edge_attr for graphs with no edges to prevent PyG errors for mdata in merged_data: - for i, store in enumerate(mdata.stores): + for store in mdata.stores: if "edge_attr" not in store: store["edge_attr"] = torch.tensor([]) + + # Ensure all attributes are float tensors to prevent torch.cat dtype issues for attr in merged_data[0].keys(): for data in merged_data: for store in data.stores: @@ -34,26 +40,35 @@ def __call__(self, data): else: store[attr] = torch.tensor(store[attr], dtype=torch.float32) + # Use PyG's batch collate for graph data x = graph_collate( GeomData, 
merged_data, follow_batch=["x", "edge_attr", "edge_index", "label"], ) + + # Handle various combinations of missing or available labels if any(x is not None for x in y): + # If any label is not None: (None, None, `1`, None) if any(x is None for x in y): + # If any label is None: (`None`, `None`, 1, `None`) non_null_labels = [i for i, r in enumerate(y) if r is not None] y = self.process_label_rows( tuple(ye for i, ye in enumerate(y) if i in non_null_labels) ) loss_kwargs["non_null_labels"] = non_null_labels else: + # If all labels are not None: (`0`, `2`, `1`, `3`) y = self.process_label_rows(y) else: + # If all labels are None: e.g., (None, None, None, None) y = None loss_kwargs["non_null_labels"] = [] + # Set node features (x) to long dtype (e.g., for categorical features) x[0].x = x[0].x.to(dtype=torch.int64) # x is a Tuple[BaseData, Mapping, Mapping] + return XYGraphData( x, y, From b77d1076b159f12bb6ae32bc028892fb11081ea8 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 20:49:18 +0200 Subject: [PATCH 162/224] doc for encoder --- .../preprocessing/property_encoder.py | 154 ++++++++++++++---- 1 file changed, 124 insertions(+), 30 deletions(-) diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index 5d6b386..aff1bde 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -3,41 +3,69 @@ import os import sys from itertools import islice -from typing import Optional import torch class PropertyEncoder(abc.ABC): - def __init__(self, property, **kwargs): + """ + Abstract base class for encoding property values. + + Args: + property: The property object associated with this encoder. + **kwargs: Additional keyword arguments. 
+ """ + + def __init__(self, property, **kwargs) -> None: self.property = property - self._encoding_length = 1 + self._encoding_length: int = 1 @property - def name(self): + def name(self) -> str: + """Name of the encoder.""" return "" def get_encoding_length(self) -> int: + """Return the length of the encoding vector.""" return self._encoding_length def set_encoding_length(self, encoding_length: int) -> None: + """Set the length of the encoding vector.""" self._encoding_length = encoding_length - def encode(self, value): + def encode(self, value) -> torch.Tensor: + """ + Encode the given value. + + Args: + value: The value to encode. + + Returns: + Encoded tensor. + """ return value - def on_start(self, **kwargs): + def on_start(self, **kwargs) -> None: + """Hook called at the start of encoding process.""" pass - def on_finish(self): + def on_finish(self) -> None: + """Hook called at the end of encoding process.""" return class IndexEncoder(PropertyEncoder): - """Encodes property values as indices. For that purpose, compiles a dynamic list of different values that have - occurred. Stores this list in a file for later reference.""" + """ + Encodes property values as indices. For that purpose, compiles a dynamic list of different values that have + occurred. Stores this list in a file for later reference. + + Args: + property: The property object. + indices_dir: Optional directory to store index files. + **kwargs: Additional keyword arguments. 
+ """ - def __init__(self, property, indices_dir=None, **kwargs): + def __init__(self, property, indices_dir: str | None = None, **kwargs) -> None: super().__init__(property, **kwargs) if indices_dir is None: indices_dir = os.path.dirname(inspect.getfile(self.__class__)) @@ -53,12 +81,17 @@ def __init__(self, property, indices_dir=None, **kwargs): self.offset = 1 @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "index" @property - def index_path(self): - """Get path to store indices of property values, create file if it does not exist yet""" + def index_path(self) -> str: + """Get path to store indices of property values, create file if it does not exist yet + + Returns: + Path to index file. + """ index_path = os.path.join( self.dirname, "bin", self.property.name, f"indices_{self.name}.txt" ) @@ -70,8 +103,12 @@ def index_path(self): pass return index_path - def on_finish(self): - """Save cache""" + def on_finish(self) -> None: + """ + Save cache + + Saves new tokens added to the cache to the index file and logs count of unknown tokens. + """ total_tokens = len(self.cache) if total_tokens > self.index_length_start: print("New tokens added to the cache, Saving them to index token file.....") @@ -99,21 +136,36 @@ def on_finish(self): f"{self.__class__.__name__} Encountered {self._count_for_unk_token} unknown tokens" ) - def encode(self, token): - """Returns a unique number for each token, automatically adds new tokens to the cache.""" + def encode(self, token: str | None) -> torch.Tensor: + """ + Returns a unique number for each token, automatically adds new tokens to the cache. + + Args: + token: The token to encode. + + Returns: + A tensor containing the encoded index. 
+ """ if token is None: self._count_for_unk_token += 1 return torch.tensor([self._unk_token_idx]) if str(token) not in self.cache: - self.cache[(str(token))] = len(self.cache) + self.cache[str(token)] = len(self.cache) return torch.tensor([self.cache[str(token)] + self.offset]) class OneHotEncoder(IndexEncoder): - """Returns one-hot encoding of the value (position in one-hot vector is defined by index).""" + """ + Returns one-hot encoding of the value (position in one-hot vector is defined by index). - def __init__(self, property, n_labels: Optional[int] = None, **kwargs): + Args: + property: The property object. + n_labels: Optional number of labels for encoding. + **kwargs: Additional keyword arguments. + """ + + def __init__(self, property, n_labels: int | None = None, **kwargs) -> None: super().__init__(property, **kwargs) self._encoding_length = n_labels # To undo any offset set by index encoder as its not relevant for one-hot-encoder (no offset needed for some unknown/reserved token) @@ -121,14 +173,21 @@ def __init__(self, property, n_labels: Optional[int] = None, **kwargs): self.offset = 0 def get_encoding_length(self) -> int: + """Return the number of classes for one-hot encoding.""" return self._encoding_length or len(self.cache) @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "one_hot" - def on_start(self, property_values): - """To get correct number of classes during encoding, cache unique tokens beforehand""" + def on_start(self, property_values: list[list[str | None]]) -> None: + """ + To get correct number of classes during encoding, cache unique tokens beforehand + + Args: + property_values: List of property value sequences. 
+ """ unique_tokens = list( dict.fromkeys( [ @@ -140,11 +199,20 @@ def on_start(self, property_values): ] ) ) - self.tokens_dict = {} + self.tokens_dict: dict[str, torch.Tensor] = {} for token in unique_tokens: self.tokens_dict[token] = super().encode(token) - def encode(self, token): + def encode(self, token: str | None) -> torch.Tensor: + """ + Returns one-hot encoded tensor for the token. + + Args: + token: The token to encode. + + Returns: + One-hot encoded tensor of shape (1, encoding_length). + """ if token not in self.tokens_dict: self._count_for_unk_token += 1 return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64) @@ -155,22 +223,48 @@ def encode(self, token): class AsIsEncoder(PropertyEncoder): - """Returns the input value as it is, useful e.g. for float values.""" + """ + Returns the input value as it is, useful e.g. for float values. + """ @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "asis" - def encode(self, token): + def encode(self, token: float | int | None) -> torch.Tensor: + """ + Return the input value as tensor, or zero tensor if None. + + Args: + token: The value to encode. + + Returns: + Tensor of shape (1,) containing the input value or zero. + """ if token is None: return torch.tensor([0]) return torch.tensor([token]) class BoolEncoder(PropertyEncoder): + """ + Encodes boolean values as 0 or 1. + """ + @property - def name(self): + def name(self) -> str: + """Name of this encoder.""" return "bool" - def encode(self, token: bool): + def encode(self, token: bool) -> torch.Tensor: + """ + Encode boolean token as tensor. + + Args: + token: Boolean value. + + Returns: + Tensor with 1 if True else 0. 
+ """ return torch.tensor([1 if token else 0]) From b3b3a694e886296392bed0d7f236247491099edb Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 5 Jul 2025 20:52:30 +0200 Subject: [PATCH 163/224] doc of structures --- chebai_graph/preprocessing/structures.py | 27 ++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/structures.py b/chebai_graph/preprocessing/structures.py index b417593..665c605 100644 --- a/chebai_graph/preprocessing/structures.py +++ b/chebai_graph/preprocessing/structures.py @@ -1,11 +1,34 @@ +import torch from chebai.preprocessing.structures import XYData class XYGraphData(XYData): - def __len__(self): + """ + Extension of XYData supporting `.to(device)` for potentially complex `x` structures. + + `x` can be: + - a tensor, + - a tuple of tensors or dicts of tensors, + and this class recursively sends all tensors to the specified device. + + Args: + Inherits from XYData. + """ + + def __len__(self) -> int: + """Return the length of y.""" return len(self.y) - def to_x(self, device): + def to_x(self, device: torch.device | str) -> object: + """ + Move the input features `x` to the given device. + + Args: + device: torch device or device string (e.g. 'cpu' or 'cuda'). + + Returns: + The input `x` moved to the specified device, preserving structure. 
+ """ if isinstance(self.x, tuple): res = [] for elem in self.x: From 8f0a171493162604cc39984a8990ab20e9b28a2e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 20:12:14 +0200 Subject: [PATCH 164/224] lint and pre-commit version should match --- .github/workflows/lint.yml | 2 +- .pre-commit-config.yaml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1b63c41..bb9154f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install black ruff + pip install black==25.1.0 ruff==0.12.2 - name: Run Black run: black --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e32d80c..cbb7284 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: "24.2.0" + rev: "25.1.0" hooks: - id: black - id: black-jupyter # for formatting jupyter-notebook @@ -25,7 +25,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.1 + rev: v0.12.2 hooks: - id: ruff - args: [] # No --fix, disables formatting + args: [--fix] From 98f007f575a7802bb3353e38e9059944c8c84c5f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 23:01:12 +0200 Subject: [PATCH 165/224] doc model base --- chebai_graph/models/base.py | 432 +++++++++++++++++++++++------------- 1 file changed, 283 insertions(+), 149 deletions(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 016cbe3..b73158a 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -8,17 +8,54 @@ class GraphBaseNet(ChebaiBaseNet, ABC): - def _get_prediction_and_labels(self, data, labels, output): + """ + Base class for graph-based prediction networks. 
+ """ + + def _get_prediction_and_labels( + self, data: XYData, labels: torch.Tensor, output: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply sigmoid activation to outputs and return processed labels. + + Args: + data (XYData): Input batch data. + labels (torch.Tensor): Ground-truth labels. + output (torch.Tensor): Raw model output. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of (predictions, labels). + """ return torch.sigmoid(output), labels.int() - def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: + def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor | None: + """ + Process labels from XYData batch. + + Returns: + torch.Tensor | None: Processed labels if present, else None. + """ return batch.y.float() if batch.y is not None else None class GraphModelBase(torch.nn.Module, ABC): - """Base class for graph-based models with a configuration dictionary.""" - - def __init__(self, config: dict, **kwargs): + """ + Abstract base class for graph models with configurable architecture. + """ + + def __init__(self, config: dict, **kwargs) -> None: + """ + Initialize model hyperparameters from configuration. + + Args: + config (dict): Configuration dictionary with keys: + - 'hidden_length' + - 'dropout_rate' + - 'n_conv_layers' + - 'n_atom_properties' + - 'n_bond_properties' + **kwargs: Additional keyword arguments for torch.nn.Module. + """ super().__init__(**kwargs) self.hidden_length = int(config["hidden_length"]) self.dropout_rate = float(config["dropout_rate"]) @@ -28,7 +65,22 @@ def __init__(self, config: dict, **kwargs): class GraphNetWrapper(GraphBaseNet, ABC): - def __init__(self, config: dict, n_linear_layers, n_molecule_properties, **kwargs): + """ + Base wrapper class for GNNs with linear layers for property prediction. + """ + + def __init__( + self, config: dict, n_linear_layers: int, n_molecule_properties: int, **kwargs + ) -> None: + """ + Initialize the GNN and linear layers. 
+ + Args: + config (dict): Model configuration. + n_linear_layers (int): Number of linear layers. + n_molecule_properties (int): Number of molecular-level features. + **kwargs: Additional arguments. + """ super().__init__(**kwargs) self.gnn = self._get_gnn(config) gnn_out_dim = ( @@ -49,13 +101,52 @@ def __init__(self, config: dict, n_linear_layers, n_molecule_properties, **kwarg ) @abstractmethod - def _get_gnn(self, config): + def _get_gnn(self, config: dict) -> torch.nn.Module: + """ + Create the graph neural network. + + Args: + config (dict): Configuration dictionary. + + Returns: + torch.nn.Module: Instantiated GNN module. + """ pass - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Compute input dimension for the linear layers. + + Args: + gnn_out_dim (int): Output dimension of GNN. + n_molecule_properties (int): Number of molecule-level features. + + Returns: + int: Total input dimension. + """ return gnn_out_dim + n_molecule_properties - def _get_linear_module_list(self, n_linear_layers, in_dim, hidden_dim, out_dim): + def _get_linear_module_list( + self, + n_linear_layers: int, + in_dim: int, + hidden_dim: int, + out_dim: int, + ) -> torch.nn.Sequential: + """ + Construct a sequential module of linear layers. + + Args: + n_linear_layers (int): Number of linear layers. + in_dim (int): Input dimension. + hidden_dim (int): Hidden dimension. + out_dim (int): Output dimension. + + Returns: + torch.nn.Sequential: Linear layers with activations. + """ if n_linear_layers < 1: raise ValueError("n_linear_layers must be at least 1") @@ -72,7 +163,16 @@ def _get_linear_module_list(self, n_linear_layers, in_dim, hidden_dim, out_dim): return torch.nn.Sequential(*layers) - def forward(self, batch): + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass through GNN, pooling and linear layers. 
+ + Args: + batch (dict): Input batch with graph features. + + Returns: + torch.Tensor: Predicted output. + """ graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) a = self.gnn(batch) @@ -82,127 +182,134 @@ def forward(self, batch): class AugmentedNodePoolingNet(GraphNetWrapper, ABC): - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): - # atom_embeddings + molecule attributes + augmented_node_embeddings + """ + Pooling using atom and augmented node embeddings. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Include augmented node embeddings in input dimension. + - atom_embeddings + molecule attributes + augmented_node_embeddings + Returns: + int: Total input dimension. + """ return gnn_out_dim + n_molecule_properties + gnn_out_dim - def forward(self, batch): + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass with separate pooling for atom and augmented nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. 
+ """ graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) - is_atom_node = graph_data.is_atom_node.bool() # Boolean mask: shape [num_nodes] + is_atom_node = graph_data.is_atom_node.bool() is_augmented_node = ~is_atom_node node_embeddings = self.gnn(batch) - atoms_embeddings = node_embeddings[is_atom_node] atoms_batch = graph_data.batch[is_atom_node] augmented_nodes_embeddings = node_embeddings[is_augmented_node] augmented_nodes_batch = graph_data.batch[is_augmented_node] - # Scatter add separately atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) aug_nodes_vec = scatter_add( augmented_nodes_embeddings, augmented_nodes_batch, dim=0 ) - # Concatenate all graph_vector = torch.cat( - [ - atoms_vec, - graph_data.molecule_attr, - aug_nodes_vec, - ], - dim=1, + [atoms_vec, graph_data.molecule_attr, aug_nodes_vec], dim=1 ) - return self.lin_sequential(graph_vector) class GraphNodePoolingNet(GraphNetWrapper, ABC): - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): - # all_nodes_embeddings_except_graph_node + molecule attributes + graph_node_embedding + """ + Pooling using non-graph nodes and graph node embeddings. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension including graph node embeddings. + - all_nodes_embeddings_except_graph_node + molecule attributes + graph_node_embedding + + Returns: + int: Total input dimension. + """ return gnn_out_dim + n_molecule_properties + gnn_out_dim - def forward(self, batch): + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass with separate pooling for graph and other nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. 
+ """ graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) is_graph_node = graph_data.is_graph_node.bool() is_not_graph_node = ~is_graph_node node_embeddings = self.gnn(batch) - graph_node_embedding = node_embeddings[is_graph_node] graph_node_batch = graph_data.batch[is_graph_node] remaining_nodes_embedding = node_embeddings[is_not_graph_node] remaining_nodes_batch = graph_data.batch[is_not_graph_node] - # Scatter add separately graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) remaining_nodes_vec = scatter_add( remaining_nodes_embedding, remaining_nodes_batch, dim=0 ) - # Concatenate all graph_vector = torch.cat( - [ - remaining_nodes_vec, - graph_data.molecule_attr, - graph_node_vec, - ], - dim=1, + [remaining_nodes_vec, graph_data.molecule_attr, graph_node_vec], dim=1 ) - return self.lin_sequential(graph_vector) -class FGNodePoolingNet(GraphNetWrapper, ABC): - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): - # all_nodes_embeddings_except_fg_nodes + molecule attributes + fg_node_embedding +class FGNodePoolingNoGraphNodeNet(GraphNetWrapper, ABC): + """ + Graph Node not considered here in any computation. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Compute input dimension including: + - atom_embeddings + - molecule attributes + - functional_group_node_embeddings + + Returns: + int: Total input dimension. 
+ """ return gnn_out_dim + n_molecule_properties + gnn_out_dim - def forward(self, batch): - graph_data = batch["features"][0] - assert isinstance(graph_data, GraphData) - is_graph_node = graph_data.is_graph_node.bool() - is_atom_node = graph_data.is_atom_node.bool() - is_fg_node = (~is_atom_node) & (~is_graph_node) - is_remaining_node = ~is_fg_node - - node_embeddings = self.gnn(batch) - - remaining_nodes_embedding = node_embeddings[is_remaining_node] - remaining_nodes_batch = graph_data.batch[is_remaining_node] - - fg_nodes_embeddings = node_embeddings[is_fg_node] - fg_nodes_batch = graph_data.batch[is_fg_node] - - # Scatter add separately - remaining_nodes_vec = scatter_add( - remaining_nodes_embedding, remaining_nodes_batch, dim=0 - ) - fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) - - # Concatenate all - graph_vector = torch.cat( - [ - remaining_nodes_vec, - graph_data.molecule_attr, - fg_nodes_vec, - ], - dim=1, - ) - - return self.lin_sequential(graph_vector) - + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling atoms and functional group nodes. + Graph nodes are ignored. -class GraphNodeFGNodePoolingNet(GraphNetWrapper, ABC): - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): - # atom_embeddings + molecule attributes + functional_group_node_embeddings + graph_node_embeddings - return gnn_out_dim + n_molecule_properties + gnn_out_dim + gnn_out_dim + Args: + batch (dict): Input batch. - def forward(self, batch): + Returns: + torch.Tensor: Predicted output. 
+ """ graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) is_graph_node = graph_data.is_graph_node.bool() @@ -211,81 +318,52 @@ def forward(self, batch): node_embeddings = self.gnn(batch) - graph_node_embedding = node_embeddings[is_graph_node] - graph_node_batch = graph_data.batch[is_graph_node] - atoms_embeddings = node_embeddings[is_atom_node] atoms_batch = graph_data.batch[is_atom_node] fg_nodes_embeddings = node_embeddings[is_fg_node] fg_nodes_batch = graph_data.batch[is_fg_node] - # Scatter add separately - graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) - # Concatenate all graph_vector = torch.cat( - [ - atoms_vec, - graph_data.molecule_attr, - fg_nodes_vec, - graph_node_vec, - ], - dim=1, + [atoms_vec, graph_data.molecule_attr, fg_nodes_vec], dim=1 ) return self.lin_sequential(graph_vector) -class FGNodePoolingNoGraphNodeNet(GraphNetWrapper, ABC): - """Graph Node not considered here in any computation""" - - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): - # atom_embeddings + molecule attributes + functional_group_node_embeddings +class GraphNodeNoFGNodePoolingNet(GraphNetWrapper, ABC): + """ + Functional Group Nodes not considered here in any computation. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Compute input dimension including: + - atom_embeddings + - molecule attributes + - graph_node_embeddings + + Returns: + int: Total input dimension. 
+ """ return gnn_out_dim + n_molecule_properties + gnn_out_dim - def forward(self, batch): - graph_data = batch["features"][0] - assert isinstance(graph_data, GraphData) - is_graph_node = graph_data.is_graph_node.bool() - is_atom_node = graph_data.is_atom_node.bool() - is_fg_node = (~is_atom_node) & (~is_graph_node) - - node_embeddings = self.gnn(batch) - - atoms_embeddings = node_embeddings[is_atom_node] - atoms_batch = graph_data.batch[is_atom_node] + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling atoms and graph nodes. + Functional group nodes are ignored. - fg_nodes_embeddings = node_embeddings[is_fg_node] - fg_nodes_batch = graph_data.batch[is_fg_node] + Args: + batch (dict): Input batch. - # Scatter add separately - atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) - fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) - - # Concatenate all - graph_vector = torch.cat( - [ - atoms_vec, - graph_data.molecule_attr, - fg_nodes_vec, - ], - dim=1, - ) - - return self.lin_sequential(graph_vector) - - -class GraphNodeNoFGNodePoolingNet(GraphNetWrapper, ABC): - """Functional Group Nodes not considered here in any computation""" - - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): - # atom_embeddings + molecule attributes + graph_node_embeddings - return gnn_out_dim + n_molecule_properties + gnn_out_dim - - def forward(self, batch): + Returns: + torch.Tensor: Predicted output. 
+ """ graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) is_graph_node = graph_data.is_graph_node.bool() @@ -299,28 +377,42 @@ def forward(self, batch): atoms_embeddings = node_embeddings[is_atom_node] atoms_batch = graph_data.batch[is_atom_node] - # Scatter add separately graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) - # Concatenate all graph_vector = torch.cat( - [ - atoms_vec, - graph_data.molecule_attr, - graph_node_vec, - ], - dim=1, + [atoms_vec, graph_data.molecule_attr, graph_node_vec], dim=1 ) return self.lin_sequential(graph_vector) class AugmentedOnlyPoolingNet(GraphNetWrapper, ABC): - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + """ + Only augmented node embeddings are pooled. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension using only augmented node embeddings. + + Returns: + int: Total input dimension. + """ return gnn_out_dim + n_molecule_properties - def forward(self, batch): + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling only augmented nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ graph_data = batch["features"][0] is_atom_node = graph_data.is_atom_node.bool() augmented_nodes_embeddings = self.gnn(batch)[~is_atom_node] @@ -335,10 +427,31 @@ def forward(self, batch): class FGOnlyPoolingNet(GraphNetWrapper, ABC): - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + """ + Only functional group node embeddings are pooled. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension using only FG node embeddings. + + Returns: + int: Total input dimension. 
+ """ return gnn_out_dim + n_molecule_properties - def forward(self, batch): + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling only functional group nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ graph_data = batch["features"][0] is_graph_node = graph_data.is_graph_node.bool() is_atom_node = graph_data.is_atom_node.bool() @@ -353,10 +466,31 @@ def forward(self, batch): class GraphNodeOnlyPoolingNet(GraphNetWrapper, ABC): - def _get_lin_seq_input_dim(self, gnn_out_dim, n_molecule_properties): + """ + Only graph node embeddings are pooled. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Return input dimension using only graph node embeddings. + + Returns: + int: Total input dimension. + """ return gnn_out_dim + n_molecule_properties - def forward(self, batch): + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass pooling only graph nodes. + + Args: + batch (dict): Input batch. + + Returns: + torch.Tensor: Predicted output. + """ graph_data = batch["features"][0] is_graph_node = graph_data.is_graph_node.bool() From 61473e2711f40f3af7411fa8097854c42cb9a8cf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 23:57:08 +0200 Subject: [PATCH 166/224] doc for models augmented --- chebai_graph/models/augmented.py | 72 ++++++++--- chebai_graph/models/base.py | 205 +++++++++++++++++++++++++++++-- 2 files changed, 252 insertions(+), 25 deletions(-) diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py index 35f72aa..b1cae5b 100644 --- a/chebai_graph/models/augmented.py +++ b/chebai_graph/models/augmented.py @@ -13,60 +13,96 @@ class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred): - """GNN for graph-level prediction for augmented graphs""" + """ + Combines: + - AugmentedNodePoolingNet: Pools atom and augmented node embeddings with molecule attributes. 
+ - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedAugNodePoolGraphPred" + ... class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, ResGatedGraphPred): - """GNN for graph-level prediction for augmented graphs""" + """ + Combines: + - GraphNodePoolingNet: Pools atom and graph node embeddings with molecule attributes. + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedGraphNodePoolGraphPred" + ... class ResGatedFGNodePoolGraphPred(FGNodePoolingNet, ResGatedGraphPred): - """GNN for graph-level prediction for augmented graphs""" + """ + Combines: + - FGNodePoolingNet: Pools functional group nodes and other nodes with molecule attributes. + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedFGNodePoolGraphPred" + ... class ResGatedGraphNodeFGNodePoolGraphPred( GraphNodeFGNodePoolingNet, ResGatedGraphPred ): - """GNN for graph-level prediction for augmented graphs""" + """ + Combines: + - GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes with molecule attributes. + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedGraphNodeFGNodePoolGraphPred" + ... class ResGatedGraphNodeNoFGNodeGraphPred( GraphNodeNoFGNodePoolingNet, ResGatedGraphPred ): - """GNN for graph-level prediction for augmented graphs without FG nodes""" + """ + Combines: + - GraphNodeNoFGNodePoolingNet: Pools atom and graph nodes, excluding functional groups. + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedGraphNodeNoFGNodeGraphPred" + ... class ResGatedFGNodeNoGraphNodeGraphPred( FGNodePoolingNoGraphNodeNet, ResGatedGraphPred ): - """GNN for graph-level prediction for augmented graphs without FG nodes""" + """ + Combines: + - FGNodePoolingNoGraphNodeNet: Pools atom and functional group nodes, excluding graph nodes. 
+ - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedFGNodeNoGraphNodeGraphPred" + ... class ResGatedAugOnlyPoolGraphPred(AugmentedOnlyPoolingNet, ResGatedGraphPred): - """GNN for graph-level prediction for augmented graphs""" + """ + Combines: + - AugmentedOnlyPoolingNet: Pools only augmented nodes with molecule attributes. + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedAugOnlyPoolGraphPred" + ... class ResGatedGraphNodeOnlyPoolGraphPred(GraphNodeOnlyPoolingNet, ResGatedGraphPred): - """GNN for graph-level prediction for augmented graphs""" + """ + Combines: + - GraphNodeOnlyPoolingNet: Pools only graph nodes with molecule attributes. + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedGraphNodeOnlyPoolGraphPred" + ... class ResGatedFGOnlyPoolGraphPred(FGOnlyPoolingNet, ResGatedGraphPred): - """GNN for graph-level prediction for augmented graphs""" + """ + Combines: + - FGOnlyPoolingNet: Pools only functional group nodes with molecule attributes. + - ResGatedGraphPred: Residual gated network for final graph prediction. + """ - NAME = "ResGatedFGOnlyPoolGraphPred" + ... diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index b73158a..56eed31 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -183,50 +183,241 @@ def forward(self, batch: dict) -> torch.Tensor: class AugmentedNodePoolingNet(GraphNetWrapper, ABC): """ - Pooling using atom and augmented node embeddings. + A pooling network that aggregates: + - Atom node embeddings + - Molecular attributes (if provided else skipped) + - Augmented node embeddings (FG nodes and graph node) + + The concatenated vector is then passed through a linear sequential block. """ def _get_lin_seq_input_dim( self, gnn_out_dim: int, n_molecule_properties: int ) -> int: """ - Include augmented node embeddings in input dimension. 
- - atom_embeddings + molecule attributes + augmented_node_embeddings + Compute the input dimension for the final linear sequential block. + + Includes: + - Atom embeddings + - Molecular attributes + - Augmented node embeddings + + Args: + gnn_out_dim (int): Dimension of the GNN output per node. + n_molecule_properties (int): Number of molecule-level attributes. + Returns: - int: Total input dimension. + int: Total input dimension for the linear sequential block. """ return gnn_out_dim + n_molecule_properties + gnn_out_dim def forward(self, batch: dict) -> torch.Tensor: """ - Forward pass with separate pooling for atom and augmented nodes. + Forward pass for pooling node embeddings. + + Steps: + 1. Identify atom nodes and augmented nodes. + 2. Compute node embeddings with the GNN. + 3. Aggregate embeddings for atoms and augmented nodes separately using scatter add. + 4. Concatenate: + - Atom nodes vector + - Molecular attributes + - Augmented nodes vector + 5. Pass the concatenated vector through the linear sequential block. Args: - batch (dict): Input batch. + batch (dict): Input batch containing graph data and features. Returns: - torch.Tensor: Predicted output. + torch.Tensor: Output tensor after pooling and linear transformation. 
""" graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) + is_atom_node = graph_data.is_atom_node.bool() is_augmented_node = ~is_atom_node node_embeddings = self.gnn(batch) + atoms_embeddings = node_embeddings[is_atom_node] atoms_batch = graph_data.batch[is_atom_node] augmented_nodes_embeddings = node_embeddings[is_augmented_node] augmented_nodes_batch = graph_data.batch[is_augmented_node] + # Scatter add separately atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) aug_nodes_vec = scatter_add( augmented_nodes_embeddings, augmented_nodes_batch, dim=0 ) + # Concatenate all graph_vector = torch.cat( [atoms_vec, graph_data.molecule_attr, aug_nodes_vec], dim=1 ) + + return self.lin_sequential(graph_vector) + + +class FGNodePoolingNet(GraphNetWrapper, ABC): + """ + A pooling network that pools node embeddings by aggregating: + - All non-functional-group nodes' embeddings (atom and graph node) + - Molecular attributes + - Functional group node embeddings + + The concatenated vector is then passed through a linear sequential block. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Computes the input dimension for the final linear sequential block. + + Combines: + - All nodes embeddings except functional group nodes + - Molecular attributes + - Functional group node embeddings + + Args: + gnn_out_dim (int): Dimension of the GNN output per node. + n_molecule_properties (int): Number of molecule-level attributes. + + Returns: + int: Total input dimension for the linear sequential block. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass for pooling node embeddings. + + Steps: + 1. Identify graph, atom, and functional group nodes. + 2. Aggregate embeddings for remaining nodes and functional group nodes separately. + 3. 
Concatenate: + - Remaining nodes vector + - Molecular attributes + - Functional group nodes vector + 4. Pass the concatenated vector through the linear sequential block. + + Args: + batch (dict): Batch containing graph data and features. + + Returns: + torch.Tensor: Output tensor after pooling and linear transformation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + is_remaining_node = ~is_fg_node + + node_embeddings = self.gnn(batch) + + remaining_nodes_embedding = node_embeddings[is_remaining_node] + remaining_nodes_batch = graph_data.batch[is_remaining_node] + + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] + + # Scatter add separately + remaining_nodes_vec = scatter_add( + remaining_nodes_embedding, remaining_nodes_batch, dim=0 + ) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) + + # Concatenate all + graph_vector = torch.cat( + [remaining_nodes_vec, graph_data.molecule_attr, fg_nodes_vec], dim=1 + ) + + return self.lin_sequential(graph_vector) + + +class GraphNodeFGNodePoolingNet(GraphNetWrapper, ABC): + """ + A pooling network that pools node embeddings by aggregating: + - Atom nodes + - Molecular attributes + - Functional group node embeddings + - Graph node embeddings + + The concatenated vector is then passed through a linear sequential block. + """ + + def _get_lin_seq_input_dim( + self, gnn_out_dim: int, n_molecule_properties: int + ) -> int: + """ + Computes the input dimension for the final linear sequential block. + + Combines: + - Atom embeddings + - Molecular attributes + - Functional group node embeddings + - Graph node embeddings + + Args: + gnn_out_dim (int): Dimension of the GNN output per node. + n_molecule_properties (int): Number of molecule-level attributes. 
+ + Returns: + int: Total input dimension for the linear sequential block. + """ + return gnn_out_dim + n_molecule_properties + gnn_out_dim + gnn_out_dim + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass for pooling node embeddings. + + Steps: + 1. Identify graph, atom, and functional group nodes. + 2. Aggregate embeddings for each node type separately. + 3. Concatenate: + - Atom nodes vector + - Molecular attributes + - Functional group nodes vector + - Graph node vector + 4. Pass the concatenated vector through the linear sequential block. + + Args: + batch (dict): Batch containing graph data and features. + + Returns: + torch.Tensor: Output tensor after pooling and linear transformation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + is_graph_node = graph_data.is_graph_node.bool() + is_atom_node = graph_data.is_atom_node.bool() + is_fg_node = (~is_atom_node) & (~is_graph_node) + + node_embeddings = self.gnn(batch) + + graph_node_embedding = node_embeddings[is_graph_node] + graph_node_batch = graph_data.batch[is_graph_node] + + atoms_embeddings = node_embeddings[is_atom_node] + atoms_batch = graph_data.batch[is_atom_node] + + fg_nodes_embeddings = node_embeddings[is_fg_node] + fg_nodes_batch = graph_data.batch[is_fg_node] + + # Scatter add separately + graph_node_vec = scatter_add(graph_node_embedding, graph_node_batch, dim=0) + atoms_vec = scatter_add(atoms_embeddings, atoms_batch, dim=0) + fg_nodes_vec = scatter_add(fg_nodes_embeddings, fg_nodes_batch, dim=0) + + # Concatenate all + graph_vector = torch.cat( + [atoms_vec, graph_data.molecule_attr, fg_nodes_vec, graph_node_vec], dim=1 + ) + return self.lin_sequential(graph_vector) From 823a3338368db5a4a80cdfe266f38c9580a9a55a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 7 Jul 2025 00:02:10 +0200 Subject: [PATCH 167/224] doc model resgated and gat --- chebai_graph/models/gat.py | 56 +++++++++++++++++++++++--- chebai_graph/models/resgated.py 
| 70 +++++++++++++++++++++++++++------ 2 files changed, 109 insertions(+), 17 deletions(-) diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py index cb4e2f4..858abc9 100644 --- a/chebai_graph/models/gat.py +++ b/chebai_graph/models/gat.py @@ -7,11 +7,28 @@ class GATGraphConvNetBase(GraphModelBase): - def __init__(self, config, **kwargs): + """ + Graph Attention Network (GAT) base module for graph convolution. + + Uses PyTorch Geometric's `GAT` implementation to process atomic node features + and bond edge attributes through multiple attention heads and layers. + """ + + def __init__(self, config: dict, **kwargs): + """ + Initialize the GATGraphConvNetBase. + + Args: + config (dict): Model configuration containing: + - 'heads' (int): Number of attention heads. + - 'v2' (bool): Whether to use the GATv2 variant. + - Other required GraphModelBase parameters. + **kwargs: Additional arguments for the base class. + """ super().__init__(config=config, **kwargs) self.heads = int(config["heads"]) self.v2 = bool(config["v2"]) - self.activation = ELU() # instantiate once + self.activation = ELU() # Instantiate ELU once for reuse. self.gat = GAT( in_channels=self.n_atom_properties, hidden_channels=self.hidden_length, @@ -24,20 +41,49 @@ def __init__(self, config, **kwargs): ) def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass through the GAT network. + + Processes atomic node features and edge attributes, and applies + an ELU activation to the output. + + Args: + batch (dict): Input batch containing: + - 'features': A list with a `GraphData` object as its first element. + + Returns: + torch.Tensor: Node embeddings after GAT and activation. 
+ """ graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) - a = self.gat( + out = self.gat( x=graph_data.x.float(), edge_index=graph_data.edge_index, edge_attr=graph_data.edge_attr, ) - return self.activation(a) + return self.activation(out) class GATGraphPred(GraphNetWrapper): + """ + GAT-based graph prediction model. + + Uses a `GATGraphConvNetBase` as the GNN backbone for generating node embeddings, + which are then pooled by the GraphNetWrapper for final prediction. + """ + NAME = "GATGraphPred" - def _get_gnn(self, config): + def _get_gnn(self, config: dict) -> GATGraphConvNetBase: + """ + Instantiate the GAT graph convolutional network base. + + Args: + config (dict): Model configuration. + + Returns: + GATGraphConvNetBase: The initialized GNN. + """ return GATGraphConvNetBase(config=config) diff --git a/chebai_graph/models/resgated.py b/chebai_graph/models/resgated.py index 5667be6..9244327 100644 --- a/chebai_graph/models/resgated.py +++ b/chebai_graph/models/resgated.py @@ -8,20 +8,36 @@ class ResGatedGraphConvNetBase(GraphModelBase): - """GNN that supports edge attributes""" + """ + Residual Gated Graph Convolutional Network with edge attributes support. + + This model uses a stack of `ResGatedGraphConv` layers from PyTorch Geometric, + allowing edge attributes as part of message passing. A final projection layer maps + to the hidden length specified for downstream graph prediction tasks. + """ NAME = "ResGatedGraphConvNetBase" - def __init__(self, config, **kwargs): + def __init__(self, config: dict, **kwargs): + """ + Initialize the ResGatedGraphConvNetBase. + + Args: + config (dict): Configuration dictionary with keys: + - 'in_length' (int): Intermediate feature length used in GNN layers. + - Other parameters inherited from GraphModelBase. + **kwargs: Additional keyword arguments passed to GraphModelBase. 
+ """ super().__init__(config=config, **kwargs) self.in_length = int(config["in_length"]) self.activation = F.elu self.dropout = nn.Dropout(self.dropout_rate) - self.convs = torch.nn.ModuleList([]) + self.convs = torch.nn.ModuleList() for i in range(self.n_conv_layers): if i == 0: + # Initial layer uses atom features as input self.convs.append( tgnn.ResGatedGraphConv( self.n_atom_properties, @@ -30,36 +46,66 @@ def __init__(self, config, **kwargs): edge_dim=self.n_bond_properties, ) ) + # Intermediate layers self.convs.append( tgnn.ResGatedGraphConv( self.in_length, self.in_length, edge_dim=self.n_bond_properties ) ) + + # Final projection layer to hidden dimension self.final_conv = tgnn.ResGatedGraphConv( self.in_length, self.hidden_length, edge_dim=self.n_bond_properties ) - def forward(self, batch): + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass through residual gated GNN layers. + + Args: + batch (dict): A batch containing: + - 'features': A list with a `GraphData` instance as the first element. + + Returns: + torch.Tensor: Node-level embeddings of shape [num_nodes, hidden_length]. + """ graph_data = batch["features"][0] assert isinstance(graph_data, GraphData) - a = graph_data.x.float() - # a = self.embedding(a) + + x = graph_data.x.float() # Atom features for conv in self.convs: assert isinstance(conv, tgnn.ResGatedGraphConv) - a = self.activation( - conv(a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr) + x = self.activation( + conv(x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr) ) - a = self.activation( + + x = self.activation( self.final_conv( - a, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr + x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr ) ) - return a + + return x class ResGatedGraphPred(GraphNetWrapper): + """ + Residual Gated GNN for Graph Prediction. + + Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings. 
+ """ + NAME = "ResGatedGraphPred" - def _get_gnn(self, config): + def _get_gnn(self, config: dict) -> ResGatedGraphConvNetBase: + """ + Instantiate the residual gated GNN backbone. + + Args: + config (dict): Model configuration. + + Returns: + ResGatedGraphConvNetBase: The GNN encoder. + """ return ResGatedGraphConvNetBase(config=config) From 7e074f5fe1b9b2ea7a084d4584f45974ac7e72da Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 8 Jul 2025 16:39:18 +0200 Subject: [PATCH 168/224] fix GAT model edge-index and act error --- chebai_graph/models/gat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py index 858abc9..c289e9e 100644 --- a/chebai_graph/models/gat.py +++ b/chebai_graph/models/gat.py @@ -37,7 +37,7 @@ def __init__(self, config: dict, **kwargs): edge_dim=self.n_bond_properties, heads=self.heads, v2=self.v2, - act=ELU, + act=self.activation, ) def forward(self, batch: dict) -> torch.Tensor: @@ -59,7 +59,7 @@ def forward(self, batch: dict) -> torch.Tensor: out = self.gat( x=graph_data.x.float(), - edge_index=graph_data.edge_index, + edge_index=graph_data.edge_index.long(), edge_attr=graph_data.edge_attr, ) From 354f2d1a2738c7f5153a93a9201216a4d92ea71a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 15 Jul 2025 12:43:12 +0200 Subject: [PATCH 169/224] data configs for augmented graphs --- configs/data/chebi50_aug_all_props.yml | 21 +++++++++++++++++++ ...ted_gnn.yml => chebi50_aug_props_only.yml} | 5 ++++- configs/data/chebi50_augmented_baseline.yml | 2 +- 3 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 configs/data/chebi50_aug_all_props.yml rename configs/data/{chebi50_augmented_gnn.yml => chebi50_aug_props_only.yml} (53%) diff --git a/configs/data/chebi50_aug_all_props.yml b/configs/data/chebi50_aug_all_props.yml new file mode 100644 index 0000000..160df0e --- /dev/null +++ b/configs/data/chebi50_aug_all_props.yml @@ -0,0 +1,21 @@ +class_path: 
chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp +init_args: + properties: + # Atom properties + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds + - chebai_graph.preprocessing.properties.AugAtomCharge + - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomHybridization + - chebai_graph.preprocessing.properties.AugAtomNumHs + - chebai_graph.preprocessing.properties.AtomFunctionalGroup + - chebai_graph.preprocessing.properties.AtomNodeLevel + - chebai_graph.preprocessing.properties.AtomRingSize + - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG + - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG + - chebai_graph.preprocessing.properties.IsFGAlkyl + # Bond properties + - chebai_graph.preprocessing.properties.AugBondType + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondAromaticity + - chebai_graph.preprocessing.properties.BondLevel diff --git a/configs/data/chebi50_augmented_gnn.yml b/configs/data/chebi50_aug_props_only.yml similarity index 53% rename from configs/data/chebi50_augmented_gnn.yml rename to configs/data/chebi50_aug_props_only.yml index e6482ef..a81e7f1 100644 --- a/configs/data/chebi50_augmented_gnn.yml +++ b/configs/data/chebi50_aug_props_only.yml @@ -1,9 +1,12 @@ -class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphFGAugmentorReader +class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp init_args: properties: # Atom properties - chebai_graph.preprocessing.properties.AtomFunctionalGroup - chebai_graph.preprocessing.properties.AtomNodeLevel - chebai_graph.preprocessing.properties.AtomRingSize + - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG + - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG + - chebai_graph.preprocessing.properties.IsFGAlkyl # Bond properties - 
chebai_graph.preprocessing.properties.BondLevel diff --git a/configs/data/chebi50_augmented_baseline.yml b/configs/data/chebi50_augmented_baseline.yml index cab8b56..91be1fa 100644 --- a/configs/data/chebi50_augmented_baseline.yml +++ b/configs/data/chebi50_augmented_baseline.yml @@ -10,4 +10,4 @@ init_args: - chebai_graph.preprocessing.properties.AugBondType - chebai_graph.preprocessing.properties.AugBondInRing - chebai_graph.preprocessing.properties.AugBondAromaticity - - chebai_graph.preprocessing.properties.AugRDKit2DNormalized + #- chebai_graph.preprocessing.properties.AugRDKit2DNormalized From b68b5202a8488e60a306d729523c44f78045396e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 15 Jul 2025 12:44:09 +0200 Subject: [PATCH 170/224] remove mol props comment --- configs/data/chebi50_augmented_baseline.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/data/chebi50_augmented_baseline.yml b/configs/data/chebi50_augmented_baseline.yml index 91be1fa..e8520f3 100644 --- a/configs/data/chebi50_augmented_baseline.yml +++ b/configs/data/chebi50_augmented_baseline.yml @@ -10,4 +10,3 @@ init_args: - chebai_graph.preprocessing.properties.AugBondType - chebai_graph.preprocessing.properties.AugBondInRing - chebai_graph.preprocessing.properties.AugBondAromaticity - #- chebai_graph.preprocessing.properties.AugRDKit2DNormalized From 62e61d5eef49451fb53cd8d4f0a9854e88e689ee Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 15 Jul 2025 12:51:21 +0200 Subject: [PATCH 171/224] update model configs --- configs/model/gnn_resgated_aug.yml | 13 ------------- configs/model/resgated.yml | 4 ++-- 2 files changed, 2 insertions(+), 15 deletions(-) delete mode 100644 configs/model/gnn_resgated_aug.yml diff --git a/configs/model/gnn_resgated_aug.yml b/configs/model/gnn_resgated_aug.yml deleted file mode 100644 index d7869ca..0000000 --- a/configs/model/gnn_resgated_aug.yml +++ /dev/null @@ -1,13 +0,0 @@ -class_path: chebai_graph.models.ResGatedAugmentedGraphPred -init_args: 
- optimizer_kwargs: - lr: 1e-3 - config: - in_length: 256 - hidden_length: 512 - dropout_rate: 0.1 - n_conv_layers: 3 - n_linear_layers: 3 - n_atom_properties: 158 - n_bond_properties: 7 - n_molecule_properties: 200 diff --git a/configs/model/resgated.yml b/configs/model/resgated.yml index 83746ae..581e742 100644 --- a/configs/model/resgated.yml +++ b/configs/model/resgated.yml @@ -5,9 +5,9 @@ init_args: config: in_length: 256 hidden_length: 512 - dropout_rate: 0.1 + dropout_rate: 0 n_conv_layers: 3 n_atom_properties: 158 n_bond_properties: 7 - n_molecule_properties: 200 + n_molecule_properties: 0 n_linear_layers: 2 From fb03f263cb4f5ddbbd46245cc501831a8afd2d68 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 16 Jul 2025 19:42:47 +0200 Subject: [PATCH 172/224] NO_FG instead to atom symbol to reduce no of tokens --- chebai_graph/preprocessing/reader/augmented_reader.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 59f15bd..e8d9040 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -477,18 +477,20 @@ def _set_fg_prop( ) if "" in fg_set and len(fg_set) == 1: + NO_FG = "NO_FG" if len(connected_atoms) == 1: # If there is only one atom and one edge connecting this atom to its fg_atom, # the functional group will be the symbol of this atom # This special case is to handle wildcard SMILES Eg. CHEBI:33429 atom = connected_atoms[0] - # TODO: needed or can we set to default fg prop `NO_FG`? - atom.SetProp("FG", atom.GetSymbol()) + # needed or can we set to default fg prop `NO_FG`? + # default to NO_FG, as very distinct atom symbols increases number of tokens + atom.SetProp("FG", NO_FG) else: # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name # assigned, Eg. 
CHEBI:55388, atom idx 2 and 3 ([C-]#[C-]") have no functional group name, so default FG prop is used for atom in connected_atoms: - atom.SetProp("FG", "NO_FG") + atom.SetProp("FG", NO_FG) # atom.SetProp("FG", fg_smiles) if len(fg_set - {""}) > 1: From f6b88e1f2c5c1b557dc5d9dc801a307355c7f76f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 16 Jul 2025 20:36:19 +0200 Subject: [PATCH 173/224] fix RING key error --- chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index a7748f8..0080270 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -38,6 +38,9 @@ def set_ring_properties(mol: Chem.Mol) -> list[list[set[int]]] | None: if mol is None: return + for atom in mol.GetAtoms(): + atom.SetProp("RING", "") + AllChem.GetSymmSSSR(mol) ######## SET RING PROP ######## From 331f9417e27803b2edea545b86c48c8a06573948 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 16 Jul 2025 20:57:54 +0200 Subject: [PATCH 174/224] fg detection algo - instead of atom symbol use NO_FG --- .../preprocessing/fg_detection/fg_aware_rule_based.py | 8 +++++--- chebai_graph/preprocessing/fg_detection/fg_constants.py | 2 ++ chebai_graph/preprocessing/reader/augmented_reader.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index 0080270..f3c4597 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -8,7 +8,7 @@ from rdkit.Chem import AllChem from rdkit.Chem import MolToSmiles as m2s -from chebai_graph.preprocessing.fg_detection.fg_constants import ELEMENTS +from 
chebai_graph.preprocessing.fg_detection.fg_constants import ELEMENTS, NO_FG def ring_size_processing(ring_size): @@ -1817,9 +1817,11 @@ def detect_functional_group(mol: Chem.Mol): ########################### Groups containing other elements ########################### if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: if charge == 0: - atom.SetProp("FG", atom_symbol) + # atom.SetProp("FG", atom_symbol) + atom.SetProp("FG", NO_FG) else: - atom.SetProp("FG", f"{atom_symbol}[{charge}]") + # atom.SetProp("FG", f"{atom_symbol}[{charge}]") + atom.SetProp("FG", NO_FG) else: pass diff --git a/chebai_graph/preprocessing/fg_detection/fg_constants.py b/chebai_graph/preprocessing/fg_detection/fg_constants.py index 9b71e9f..eab62dc 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_constants.py +++ b/chebai_graph/preprocessing/fg_detection/fg_constants.py @@ -9,3 +9,5 @@ "Tb", "Tc", "Te", "Th", "Ti", "Tl", "Tm", "U", "V", "W", "Xe", "Y", "Yb", "Zn", "Zr" } # fmt: on + +NO_FG = "NO_FG" diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index e8d9040..044fb30 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -9,6 +9,7 @@ from chebai_graph.preprocessing.collate import GraphCollator from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import get_structure +from chebai_graph.preprocessing.fg_detection.fg_constants import NO_FG from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.properties import constants as k @@ -477,7 +478,6 @@ def _set_fg_prop( ) if "" in fg_set and len(fg_set) == 1: - NO_FG = "NO_FG" if len(connected_atoms) == 1: # If there is only one atom and one edge connecting this atom to its fg_atom, # the functional group will be the symbol of this atom From 42ef9a230fe9fe69c801c1af17ba0735a49ced96 Mon Sep 17 00:00:00 2001 From: 
aditya0by0 Date: Wed, 16 Jul 2025 21:01:26 +0200 Subject: [PATCH 175/224] fix: 'int' object has no attribute 'split' --- chebai_graph/preprocessing/properties/augmented_properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index a0ce122..42d6aaa 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -212,7 +212,7 @@ def _check_modify_atom_prop_value( """ ring_size_str = self._get_atom_prop_value(atom, prop) if ring_size_str: - ring_sizes = list(map(int, ring_size_str.split("-"))) + ring_sizes = list(map(int, str(ring_size_str).split("-"))) # TODO: Decide ring size for atoms belongs to fused rings, rn only max ring size taken return max(ring_sizes) else: From b6a4ab3e923f2b4fbe013fe26c95f3d5519f7ba6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 16 Jul 2025 23:32:42 +0200 Subject: [PATCH 176/224] add alkyl prop to graph node --- chebai_graph/preprocessing/reader/augmented_reader.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 044fb30..d32d195 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -704,7 +704,12 @@ def _construct_nodes_to_graph_node_structure( - Graph-level node attributes. - Edge attributes for graph-level connections. 
""" - graph_node = {k.NODE_LEVEL: k.GRAPH_NODE_LEVEL, "FG": "graph_fg", "RING": "0"} + graph_node = { + k.NODE_LEVEL: k.GRAPH_NODE_LEVEL, + "FG": "graph_fg", + "RING": "0", + "is_alkyl": "0", + } graph_to_nodes_edges = {} graph_edge_index = [[], []] From c957d51ec0a04db3a30ce5b9d57e057821a0ce0d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 18 Jul 2025 13:53:35 +0200 Subject: [PATCH 177/224] Revert "fg detection algo - instead of atom symbol use NO_FG" This reverts commit 331f9417e27803b2edea545b86c48c8a06573948. --- .../preprocessing/fg_detection/fg_aware_rule_based.py | 8 +++----- chebai_graph/preprocessing/fg_detection/fg_constants.py | 2 -- chebai_graph/preprocessing/reader/augmented_reader.py | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index f3c4597..0080270 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -8,7 +8,7 @@ from rdkit.Chem import AllChem from rdkit.Chem import MolToSmiles as m2s -from chebai_graph.preprocessing.fg_detection.fg_constants import ELEMENTS, NO_FG +from chebai_graph.preprocessing.fg_detection.fg_constants import ELEMENTS def ring_size_processing(ring_size): @@ -1817,11 +1817,9 @@ def detect_functional_group(mol: Chem.Mol): ########################### Groups containing other elements ########################### if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: if charge == 0: - # atom.SetProp("FG", atom_symbol) - atom.SetProp("FG", NO_FG) + atom.SetProp("FG", atom_symbol) else: - # atom.SetProp("FG", f"{atom_symbol}[{charge}]") - atom.SetProp("FG", NO_FG) + atom.SetProp("FG", f"{atom_symbol}[{charge}]") else: pass diff --git a/chebai_graph/preprocessing/fg_detection/fg_constants.py b/chebai_graph/preprocessing/fg_detection/fg_constants.py index eab62dc..9b71e9f 
100644 --- a/chebai_graph/preprocessing/fg_detection/fg_constants.py +++ b/chebai_graph/preprocessing/fg_detection/fg_constants.py @@ -9,5 +9,3 @@ "Tb", "Tc", "Te", "Th", "Ti", "Tl", "Tm", "U", "V", "W", "Xe", "Y", "Yb", "Zn", "Zr" } # fmt: on - -NO_FG = "NO_FG" diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index d32d195..054be0e 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -9,7 +9,6 @@ from chebai_graph.preprocessing.collate import GraphCollator from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import get_structure -from chebai_graph.preprocessing.fg_detection.fg_constants import NO_FG from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.properties import constants as k @@ -478,6 +477,7 @@ def _set_fg_prop( ) if "" in fg_set and len(fg_set) == 1: + NO_FG = "NO_FG" if len(connected_atoms) == 1: # If there is only one atom and one edge connecting this atom to its fg_atom, # the functional group will be the symbol of this atom From 209726359c526694eb9026cbd63fadc0e505411a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 18 Jul 2025 14:02:53 +0200 Subject: [PATCH 178/224] fix: changed number of augmented nodes (rollback to original) --- .../preprocessing/fg_detection/fg_aware_rule_based.py | 7 +++++-- chebai_graph/preprocessing/reader/augmented_reader.py | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index 0080270..870aea0 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -34,7 +34,6 @@ def find_connected_rings(ring, remaining_rings) -> list[set[int]]: def set_ring_properties(mol: Chem.Mol) -> 
list[list[set[int]]] | None: - if mol is None: return @@ -1818,8 +1817,12 @@ def detect_functional_group(mol: Chem.Mol): if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: if charge == 0: atom.SetProp("FG", atom_symbol) + atom.SetProp("flag_no_fg", "") + # atom.SetProp("FG", NO_FG) # changes the fg-detection algo (num of fg dectected reduces) else: atom.SetProp("FG", f"{atom_symbol}[{charge}]") + atom.SetProp("flag_no_fg", "") + # atom.SetProp("FG", NO_FG) # changes the fg-detection algo (num of fg dectected reduces) else: pass @@ -1923,7 +1926,7 @@ def get_structure(mol): flat_atoms = set().union(*group) if flat_atoms.issubset(atom_idx) and len(flat_atoms) == len(atom_idx): for idx, ring_atoms in enumerate(group): - structure[f"{frag}_{idx+1}"] = { + structure[f"{frag}_{idx + 1}"] = { "atom": ring_atoms, "is_ring_fg": True, } diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 054be0e..0d30cf6 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -470,7 +470,11 @@ def _set_fg_prop( ValueError: If functional group assignment is inconsistent or missing. AssertionError: If no representative atom is found. """ - fg_set = {atom.GetProp("FG") for atom in connected_atoms} + fg_set = { + atom.GetProp("FG") + for atom in connected_atoms + if not atom.HasProp("is_no_fg") + } if not fg_set: raise ValueError( "No functional group assigned to atoms in the functional group." 
From 52cc9012356266643a7a4b8819c3056f8403647a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 18 Jul 2025 16:28:08 +0200 Subject: [PATCH 179/224] refactor set fg properties logic --- .../fg_detection/fg_aware_rule_based.py | 6 +- .../fg_detection/fg_constants.py | 3 + .../preprocessing/reader/augmented_reader.py | 89 ++++++++++++------- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py index 870aea0..f60f580 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +++ b/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py @@ -8,7 +8,7 @@ from rdkit.Chem import AllChem from rdkit.Chem import MolToSmiles as m2s -from chebai_graph.preprocessing.fg_detection.fg_constants import ELEMENTS +from .fg_constants import ELEMENTS, FLAG_NO_FG def ring_size_processing(ring_size): @@ -1817,11 +1817,11 @@ def detect_functional_group(mol: Chem.Mol): if atom.GetProp("FG") == "" and atom_symbol in ELEMENTS and not in_ring: if charge == 0: atom.SetProp("FG", atom_symbol) - atom.SetProp("flag_no_fg", "") + atom.SetProp(FLAG_NO_FG, "") # atom.SetProp("FG", NO_FG) # changes the fg-detection algo (num of fg dectected reduces) else: atom.SetProp("FG", f"{atom_symbol}[{charge}]") - atom.SetProp("flag_no_fg", "") + atom.SetProp(FLAG_NO_FG, "") # atom.SetProp("FG", NO_FG) # changes the fg-detection algo (num of fg dectected reduces) else: pass diff --git a/chebai_graph/preprocessing/fg_detection/fg_constants.py b/chebai_graph/preprocessing/fg_detection/fg_constants.py index 9b71e9f..5431264 100644 --- a/chebai_graph/preprocessing/fg_detection/fg_constants.py +++ b/chebai_graph/preprocessing/fg_detection/fg_constants.py @@ -9,3 +9,6 @@ "Tb", "Tc", "Te", "Th", "Ti", "Tl", "Tm", "U", "V", "W", "Xe", "Y", "Yb", "Zn", "Zr" } # fmt: on + + +FLAG_NO_FG = "flag_no_fg" diff --git 
a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 0d30cf6..4d0c22e 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -9,6 +9,7 @@ from chebai_graph.preprocessing.collate import GraphCollator from chebai_graph.preprocessing.fg_detection.fg_aware_rule_based import get_structure +from chebai_graph.preprocessing.fg_detection.fg_constants import FLAG_NO_FG from chebai_graph.preprocessing.properties import MolecularProperty from chebai_graph.preprocessing.properties import constants as k @@ -470,54 +471,74 @@ def _set_fg_prop( ValueError: If functional group assignment is inconsistent or missing. AssertionError: If no representative atom is found. """ - fg_set = { - atom.GetProp("FG") - for atom in connected_atoms - if not atom.HasProp("is_no_fg") - } + NO_FG = "NO_FG" + representative_atom = None + + # Check if the functional group SMILES corresponds to an alkyl group + # by removing common alkyl characters and checking if anything remains. + check = re.sub(r"[CH\-\(\)\[\]/\\@]", "", fg_smiles) + is_alkyl = "1" if len(check) == 0 else "0" + + fg_set = set() + for atom in connected_atoms: + atom.SetProp("is_alkyl", is_alkyl) + + # Set FG to NO_FG if this atom's fg is marked to be ignored + if atom.HasProp(FLAG_NO_FG): + atom.SetProp("FG", NO_FG) + + fg = atom.GetProp("FG") + fg_set.add(fg) + + # Store the last seen valid FG atom as representative + if fg and fg != NO_FG: + representative_atom = atom + + # Raise error if no FG at all was found (likely unexpected state) if not fg_set: raise ValueError( "No functional group assigned to atoms in the functional group." 
) - if "" in fg_set and len(fg_set) == 1: - NO_FG = "NO_FG" - if len(connected_atoms) == 1: - # If there is only one atom and one edge connecting this atom to its fg_atom, - # the functional group will be the symbol of this atom - # This special case is to handle wildcard SMILES Eg. CHEBI:33429 - atom = connected_atoms[0] - # needed or can we set to default fg prop `NO_FG`? - # default to NO_FG, as very distinct atom symbols increases number of tokens + # Determine how many valid functional groups are present + valid_fgs = fg_set - {"", NO_FG} + num_of_valid_fgs = len(valid_fgs) + + if num_of_valid_fgs == 0: + # fg_set = {"", NO_FG} or {""} or {NO_FG} + for atom in connected_atoms: atom.SetProp("FG", NO_FG) - else: - # If there are multiple atoms connected to the functional group, and no atoms have a functional group property/name - # assigned, Eg. CHEBI:55388, atom idx 2 and 3 ([C-]#[C-]") have no functional group name, so default FG prop is used - for atom in connected_atoms: - atom.SetProp("FG", NO_FG) - # atom.SetProp("FG", fg_smiles) + node_fg = NO_FG - if len(fg_set - {""}) > 1: + elif num_of_valid_fgs > 1: + # fg_set = {"FG1", "FG2", ...} or {"FG1", "FG2", ..., NO_FG} or + # {"FG1", "FG2", ..., ""} or {"FG1", FG2, ..., "", NO_FG} + # Inconsistent FG assignments; possibly a bug in FG detection raise ValueError( - "Connected atoms have different function groups assigned.\n" - "All Connected atoms must belong to one functional group or None" + "Connected atoms have different functional groups assigned.\n" + "All connected atoms must belong to one functional group or None." 
) - check = re.sub(r"[CH\-\(\)\[\]/\\@]", "", fg_smiles) - is_alkyl = "1" if len(check) == 0 else "0" - - representative_atom = None - for atom in connected_atoms: - if atom.GetProp("FG"): - representative_atom = atom - atom.SetProp("is_alkyl", is_alkyl) + elif num_of_valid_fgs == 1: + # fg_set = {"FG1"} or {"FG1", ""} or {"FG1", NO_FG} or {"FG1", "", NO_FG} + # Exactly one valid FG; ensure we have an atom to extract it from + if representative_atom is None: + raise AssertionError( + "Expected at least one atom with a valid functional group." + ) + node_fg = representative_atom.GetProp("FG") + # If any atom had FG as an empty string (""), backfill it with node_fg + for atom in connected_atoms: + atom.SetProp("FG", node_fg) - if representative_atom is None: - raise AssertionError("Expected at least one atom with a functional group.") + else: + # This branch is unreachable but kept for safety + raise AssertionError("Unexpected state in functional group detection.") + # Assign the final FG node metadata fg_nodes[self._idx_of_node] = { k.NODE_LEVEL: k.FG_NODE_LEVEL, - "FG": representative_atom.GetProp("FG"), + "FG": node_fg, "RING": 0, "is_alkyl": is_alkyl, } From 342cb85836dc5cd444c3af75601a2cf3ff728294 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 18 Jul 2025 16:29:35 +0200 Subject: [PATCH 180/224] atomring size not required as its included as functional group --- configs/data/chebi50_aug_all_props.yml | 1 - configs/data/chebi50_aug_props_only.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/configs/data/chebi50_aug_all_props.yml b/configs/data/chebi50_aug_all_props.yml index 160df0e..f8bc84f 100644 --- a/configs/data/chebi50_aug_all_props.yml +++ b/configs/data/chebi50_aug_all_props.yml @@ -10,7 +10,6 @@ init_args: - chebai_graph.preprocessing.properties.AugAtomNumHs - chebai_graph.preprocessing.properties.AtomFunctionalGroup - chebai_graph.preprocessing.properties.AtomNodeLevel - - chebai_graph.preprocessing.properties.AtomRingSize - 
chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG - chebai_graph.preprocessing.properties.IsFGAlkyl diff --git a/configs/data/chebi50_aug_props_only.yml b/configs/data/chebi50_aug_props_only.yml index a81e7f1..d81d303 100644 --- a/configs/data/chebi50_aug_props_only.yml +++ b/configs/data/chebi50_aug_props_only.yml @@ -4,7 +4,6 @@ init_args: # Atom properties - chebai_graph.preprocessing.properties.AtomFunctionalGroup - chebai_graph.preprocessing.properties.AtomNodeLevel - - chebai_graph.preprocessing.properties.AtomRingSize - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG - chebai_graph.preprocessing.properties.IsFGAlkyl From 5bb5f0fb325e97d6548ff5afef3dc5f9bc0bdf48 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 20 Jul 2025 00:39:44 +0200 Subject: [PATCH 181/224] restructure properties module --- .../preprocessing/properties/__init__.py | 24 +- .../properties/augmented_properties.py | 271 +----------------- chebai_graph/preprocessing/properties/base.py | 248 +++++++++++++++- .../preprocessing/properties/properties.py | 10 +- 4 files changed, 280 insertions(+), 273 deletions(-) diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 3ca40a9..9b6b393 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -3,7 +3,15 @@ # This is because augmented properties module imports from properties module # isort: off -from .base import MolecularProperty, AtomProperty, BondProperty +from .base import ( + MolecularProperty, + AtomProperty, + BondProperty, + MoleculeProperty, + AllNodeTypeProperty, + AtomNodeTypeProperty, + FGNodeTypeProperty, +) from .properties import ( AtomType, @@ -16,14 +24,12 @@ BondAromaticity, BondType, BondInRing, - MoleculeNumRings, RDKit2DNormalized, ) from 
.augmented_properties import ( AtomNodeLevel, AtomFunctionalGroup, - AtomRingSize, IsHydrogenBondDonorFG, IsHydrogenBondAcceptorFG, IsFGAlkyl, @@ -37,15 +43,20 @@ AugBondAromaticity, AugBondType, AugBondInRing, - AugRDKit2DNormalized, ) # isort: on __all__ = [ + # -------------- Properties Base classes -------------- "MolecularProperty", + "MoleculeProperty", "AtomProperty", "BondProperty", + "AllNodeTypeProperty", + "AtomNodeTypeProperty", + "FGNodeTypeProperty", + # -------------- Regular Properties ----------------- "AtomType", "NumAtomBonds", "AtomCharge", @@ -56,12 +67,10 @@ "BondAromaticity", "BondType", "BondInRing", - "MoleculeNumRings", "RDKit2DNormalized", - # -------- Augmented Molecular Properties -------- + # -------- Augmented Molecular Properties ---------- "AtomNodeLevel", "AtomFunctionalGroup", - "AtomRingSize", "IsHydrogenBondDonorFG", "IsHydrogenBondAcceptorFG", "IsFGAlkyl", @@ -75,5 +84,4 @@ "AugBondAromaticity", "AugBondType", "AugBondInRing", - "AugRDKit2DNormalized", ] diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index 42d6aaa..b08d36b 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -1,4 +1,3 @@ -import sys from abc import ABC from rdkit import Chem @@ -11,125 +10,18 @@ from . import constants as k from . import properties as pr -from .base import AtomProperty, BondProperty, FrozenPropertyAlias - -# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order -# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights -# https://mail.python.org/pipermail/python-dev/2017-December/151283.html -assert sys.version_info >= ( - 3, - 7, -), "This code requires Python 3.7 or higher." 
-# Order preservation is necessary to to create `prop_list` - +from .base import ( + AllNodeTypeProperty, + AtomNodeTypeProperty, + AugmentedBondProperty, + FGNodeTypeProperty, + FrozenPropertyAlias, +) # --------------------- Atom Properties ----------------------------- -class AugmentedAtomProperty(AtomProperty, ABC): - MAIN_KEY = "nodes" - - def get_property_value(self, augmented_mol: dict) -> list: - """ - Extract property values for atoms from the augmented molecule dictionary. - - Args: - augmented_mol (dict): Dictionary representing the augmented molecule. - - Raises: - KeyError: If required keys are missing in the dictionary. - TypeError: If types of contained objects are incorrect. - AssertionError: If the number of property values does not match number of nodes. - - Returns: - list: List of property values for all atoms, functional groups, and graph nodes. - """ - if self.MAIN_KEY not in augmented_mol: - raise KeyError( - f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" - ) - missing_keys = {"atom_nodes"} - augmented_mol[self.MAIN_KEY].keys() - if missing_keys: - raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") - - atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] - if not isinstance(atom_molecule, Chem.Mol): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' - ) - prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] - - if "fg_nodes" in augmented_mol[self.MAIN_KEY]: - fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] - if not isinstance(fg_nodes, dict): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]) must be an instance of dict ' - f"containing its properties" - ) - prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()]) - - if "graph_node" in augmented_mol[self.MAIN_KEY]: - graph_node = augmented_mol[self.MAIN_KEY]["graph_node"] - if not isinstance(graph_node, 
dict): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"](["graph_node"]) must be an instance of dict ' - f"containing its properties" - ) - prop_list.append(self.get_atom_value(graph_node)) - - assert ( - len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"] - ), "Number of property values should be equal to number of nodes" - return prop_list - - def _check_modify_atom_prop_value( - self, atom: Chem.rdchem.Atom | dict, prop: str - ) -> str | int | bool: - """ - Check that the property value for the atom/node exists and is not empty. - - Args: - atom (Chem.rdchem.Atom | dict): Atom or node representation. - prop (str): Property name. - - Raises: - ValueError: If the property is empty. - Returns: - str | int | bool: The property value. - """ - value = self._get_atom_prop_value(atom, prop) - if not value: - # Every atom/node should have given value - raise ValueError(f"'{prop}' is set but empty.") - return value - - def _get_atom_prop_value( - self, atom: Chem.rdchem.Atom | dict, prop: str - ) -> str | int | bool: - """ - Retrieve a property value from an atom or dict node. - - Args: - atom (Chem.rdchem.Atom | dict): Atom or node. - prop (str): Property name. - - Raises: - TypeError: If atom is not an expected type. - - Returns: - str | int | bool: The property value. - """ - if isinstance(atom, Chem.rdchem.Atom): - return atom.GetProp(prop) - elif isinstance(atom, dict): - return atom[prop] - else: - raise TypeError( - f"Atom/Node in key `{self.MAIN_KEY}` should be of type `Chem.rdchem.Atom` or `dict`." - ) - - -class AtomNodeLevel(AugmentedAtomProperty): +class AtomNodeLevel(AllNodeTypeProperty): def __init__(self, encoder: PropertyEncoder | None = None): """ Initialize AtomNodeLevel with an optional encoder. 
@@ -152,7 +44,7 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> str | int | bool: return self._check_modify_atom_prop_value(atom, k.NODE_LEVEL) -class AtomFunctionalGroup(AugmentedAtomProperty): +class AtomFunctionalGroup(FGNodeTypeProperty): def __init__(self, encoder: PropertyEncoder | None = None): """ Initialize AtomFunctionalGroup with an optional encoder. @@ -175,7 +67,7 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> str | int | bool: return self._check_modify_atom_prop_value(atom, "FG") -class AtomRingSize(AugmentedAtomProperty): +class AtomRingSize(FGNodeTypeProperty): def __init__(self, encoder: PropertyEncoder | None = None): """ Initialize AtomRingSize with an optional encoder. @@ -219,7 +111,7 @@ def _check_modify_atom_prop_value( return 0 -class IsHydrogenBondDonorFG(AugmentedAtomProperty): +class IsHydrogenBondDonorFG(FGNodeTypeProperty): def __init__(self, encoder: PropertyEncoder | None = None): """ Initialize IsHydrogenBondDonorFG with an optional encoder. @@ -252,7 +144,7 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> bool: return fg in self._hydrogen_bond_donor -class IsHydrogenBondAcceptorFG(AugmentedAtomProperty): +class IsHydrogenBondAcceptorFG(FGNodeTypeProperty): def __init__(self, encoder: PropertyEncoder | None = None): """ Initialize IsHydrogenBondAcceptorFG with an optional encoder. 
@@ -286,7 +178,7 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> bool: return fg in self._hydrogen_bond_acceptor -class IsFGAlkyl(AugmentedAtomProperty): +class IsFGAlkyl(FGNodeTypeProperty): def __init__(self, encoder: PropertyEncoder | None = None): """ Args: @@ -308,7 +200,7 @@ def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int: return int(self._check_modify_atom_prop_value(atom, "is_alkyl")) -class AugNodeValueDefaulter(AugmentedAtomProperty, FrozenPropertyAlias, ABC): +class AugNodeValueDefaulter(AtomNodeTypeProperty, FrozenPropertyAlias, ABC): def get_atom_value(self, atom: Chem.rdchem.Atom | dict) -> int | None: """ Get the property value for an atom or dict node. @@ -427,121 +319,6 @@ class AugAtomAromaticity(AugNodeValueDefaulter, pr.AtomAromaticity): # --------------------- Bond Properties ------------------------------ -class AugmentedBondProperty(BondProperty, ABC): - MAIN_KEY = "edges" - - def get_property_value(self, augmented_mol: dict) -> list: - """ - Get bond property values from augmented molecule dict. - - Args: - augmented_mol (dict): Augmented molecule dictionary containing edges. - - Returns: - list: List of property values for bonds in the augmented molecule. - - Raises: - KeyError: If required keys are missing in augmented_mol. - TypeError: If the expected objects are not of correct types. - AssertionError: If number of property values does not match expected edge count. 
- """ - if self.MAIN_KEY not in augmented_mol: - raise KeyError( - f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" - ) - - missing_keys = {k.WITHIN_ATOMS_EDGE} - augmented_mol[self.MAIN_KEY].keys() - if missing_keys: - raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") - - atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][k.WITHIN_ATOMS_EDGE] - if not isinstance(atom_molecule, Chem.Mol): - raise TypeError( - f'augmented_mol["{self.MAIN_KEY}"]["{k.WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' - ) - prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] - - if k.ATOM_FG_EDGE in augmented_mol[self.MAIN_KEY]: - fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE] - if not isinstance(fg_atom_edges, dict): - raise TypeError( - f"augmented_mol['{self.MAIN_KEY}'](['{k.ATOM_FG_EDGE}'])" - f"must be an instance of dict containing its properties" - ) - prop_list.extend( - [self.get_bond_value(bond) for bond in fg_atom_edges.values()] - ) - - if k.WITHIN_FG_EDGE in augmented_mol[self.MAIN_KEY]: - fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE] - if not isinstance(fg_edges, dict): - raise TypeError( - f"augmented_mol['{self.MAIN_KEY}'](['{k.WITHIN_FG_EDGE}'])" - f"must be an instance of dict containing its properties" - ) - prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) - - if k.TO_GRAPHNODE_EDGE in augmented_mol[self.MAIN_KEY]: - fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.TO_GRAPHNODE_EDGE] - if not isinstance(fg_graph_node_edges, dict): - raise TypeError( - f"augmented_mol['{self.MAIN_KEY}'](['{k.TO_GRAPHNODE_EDGE}'])" - f"must be an instance of dict containing its properties" - ) - prop_list.extend( - [self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] - ) - - num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2 - assert ( - len(prop_list) == num_directed_edges - ), f"Number of property values 
({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " - - return prop_list - - def _check_modify_bond_prop_value( - self, bond: Chem.rdchem.Bond | dict, prop: str - ) -> str: - """ - Helper to check and get bond property value. - - Args: - bond (Chem.rdchem.Bond | dict): Bond object or bond property dict. - prop (str): Property key to get. - - Returns: - str: Property value. - - Raises: - ValueError: If value is empty or falsy. - """ - value = self._get_bond_prop_value(bond, prop) - if not value: - # Every atom/node should have given value - raise ValueError(f"'{prop}' is set but empty.") - return value - - @staticmethod - def _get_bond_prop_value(bond: Chem.rdchem.Bond | dict, prop: str) -> str: - """ - Extract bond property value from bond or dict. - - Args: - bond (Chem.rdchem.Bond | dict): Bond object or dict. - prop (str): Property key. - - Returns: - str: Property value. - - Raises: - TypeError: If bond is not the expected type. - """ - if isinstance(bond, Chem.rdchem.Bond): - return bond.GetProp(prop) - elif isinstance(bond, dict): - return bond[prop] - else: - raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") class BondLevel(AugmentedBondProperty): @@ -625,23 +402,3 @@ class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): """ ... - - -# --------------------- Molecular Properties ------------------------------ -class AugmentedMolecularProperty(pr.MolecularProperty, ABC): - def get_property_value(self, augmented_mol: dict) -> list: - """ - Get molecular property values from augmented molecule dict. - - Args: - augmented_mol (dict): Augmented molecule dict. - - Returns: - list: Property values of molecule. 
- """ - mol: Chem.Mol = augmented_mol[AugmentedAtomProperty.MAIN_KEY]["atom_nodes"] - assert isinstance(mol, Chem.Mol), "Molecule should be instance of `Chem.Mol`" - return super().get_property_value(mol) - - -class AugRDKit2DNormalized(AugmentedMolecularProperty, pr.RDKit2DNormalized): ... diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py index 501a340..e480b84 100644 --- a/chebai_graph/preprocessing/properties/base.py +++ b/chebai_graph/preprocessing/properties/base.py @@ -1,3 +1,4 @@ +import sys from abc import ABC, abstractmethod from types import MappingProxyType @@ -5,6 +6,17 @@ from chebai_graph.preprocessing.property_encoder import IndexEncoder, PropertyEncoder +from . import constants as k + +# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order +# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights +# https://mail.python.org/pipermail/python-dev/2017-December/151283.html +assert sys.version_info >= ( + 3, + 7, +), "This code requires Python 3.7 or higher." +# Order preservation is necessary to to create `prop_list`in Augmented properties + class MolecularProperty(ABC): """ @@ -131,7 +143,7 @@ def get_bond_value(self, bond: Chem.rdchem.Bond) -> object: pass -class MoleculeProperty(MolecularProperty): +class MoleculeProperty(MolecularProperty, ABC): """ Class representing a global (molecule-level) property. @@ -155,8 +167,7 @@ class FrozenPropertyAlias(MolecularProperty, ABC): Example: ```python - class AugAtomType(FrozenPropertyAlias, AtomType): - ... + class AugAtomType(FrozenPropertyAlias, AtomType): ... 
``` Raises: @@ -197,3 +208,234 @@ def on_finish(self) -> None: f"to a frozen encoder at {self.encoder.index_path}" ) super().on_finish() + + +class AugmentedAtomProperty(AtomProperty, ABC): + MAIN_KEY = "nodes" + + def get_property_value(self, augmented_mol: dict) -> list: + """ + Extract property values for atoms from the augmented molecule dictionary. + + Args: + augmented_mol (dict): Dictionary representing the augmented molecule. + + Raises: + KeyError: If required keys are missing in the dictionary. + TypeError: If types of contained objects are incorrect. + AssertionError: If the number of property values does not match number of nodes. + + Returns: + list: List of property values for all atoms, functional groups, and graph nodes. + """ + if self.MAIN_KEY not in augmented_mol: + raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + + missing_keys = {"atom_nodes"} - augmented_mol[self.MAIN_KEY].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") + + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY]["atom_nodes"] + if not isinstance(atom_molecule, Chem.Mol): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol' + ) + prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()] + + if "fg_nodes" in augmented_mol[self.MAIN_KEY]: + fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"] + if not isinstance(fg_nodes, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]) must be an instance of dict ' + f"containing its properties" + ) + prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()]) + + if "graph_node" in augmented_mol[self.MAIN_KEY]: + graph_node = augmented_mol[self.MAIN_KEY]["graph_node"] + if not isinstance(graph_node, dict): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"](["graph_node"]) must be an instance of dict ' + f"containing its 
properties" + ) + prop_list.append(self.get_atom_value(graph_node)) + + assert ( + len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"] + ), "Number of property values should be equal to number of nodes" + return prop_list + + def _check_modify_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> str | int | bool: + """ + Check that the property value for the atom/node exists and is not empty. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node representation. + prop (str): Property name. + + Raises: + ValueError: If the property is empty. + + Returns: + str | int | bool: The property value. + """ + value = self._get_atom_prop_value(atom, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value + + def _get_atom_prop_value( + self, atom: Chem.rdchem.Atom | dict, prop: str + ) -> str | int | bool: + """ + Retrieve a property value from an atom or dict node. + + Args: + atom (Chem.rdchem.Atom | dict): Atom or node. + prop (str): Property name. + + Raises: + TypeError: If atom is not an expected type. + + Returns: + str | int | bool: The property value. + """ + if isinstance(atom, Chem.rdchem.Atom): + return atom.GetProp(prop) + elif isinstance(atom, dict): + return atom[prop] + else: + raise TypeError( + f"Atom/Node in key `{self.MAIN_KEY}` should be of type `Chem.rdchem.Atom` or `dict`." + ) + + +class AtomNodeTypeProperty(AugmentedAtomProperty, ABC): ... + + +class FGNodeTypeProperty(AugmentedAtomProperty, ABC): ... + + +class AllNodeTypeProperty(AugmentedAtomProperty, ABC): ... + + +class AugmentedBondProperty(BondProperty, ABC): + MAIN_KEY = "edges" + + def get_property_value(self, augmented_mol: dict) -> list: + """ + Get bond property values from augmented molecule dict. + + Args: + augmented_mol (dict): Augmented molecule dictionary containing edges. + + Returns: + list: List of property values for bonds in the augmented molecule. 
+ + Raises: + KeyError: If required keys are missing in augmented_mol. + TypeError: If the expected objects are not of correct types. + AssertionError: If number of property values does not match expected edge count. + """ + if self.MAIN_KEY not in augmented_mol: + raise KeyError( + f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict" + ) + + missing_keys = {k.WITHIN_ATOMS_EDGE} - augmented_mol[self.MAIN_KEY].keys() + if missing_keys: + raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes") + + atom_molecule: Chem.Mol = augmented_mol[self.MAIN_KEY][k.WITHIN_ATOMS_EDGE] + if not isinstance(atom_molecule, Chem.Mol): + raise TypeError( + f'augmented_mol["{self.MAIN_KEY}"]["{k.WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol' + ) + prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()] + + if k.ATOM_FG_EDGE in augmented_mol[self.MAIN_KEY]: + fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE] + if not isinstance(fg_atom_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.ATOM_FG_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend( + [self.get_bond_value(bond) for bond in fg_atom_edges.values()] + ) + + if k.WITHIN_FG_EDGE in augmented_mol[self.MAIN_KEY]: + fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE] + if not isinstance(fg_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.WITHIN_FG_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()]) + + if k.TO_GRAPHNODE_EDGE in augmented_mol[self.MAIN_KEY]: + fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.TO_GRAPHNODE_EDGE] + if not isinstance(fg_graph_node_edges, dict): + raise TypeError( + f"augmented_mol['{self.MAIN_KEY}'](['{k.TO_GRAPHNODE_EDGE}'])" + f"must be an instance of dict containing its properties" + ) + prop_list.extend( + 
[self.get_bond_value(bond) for bond in fg_graph_node_edges.values()] + ) + + num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2 + assert ( + len(prop_list) == num_directed_edges + ), f"Number of property values ({len(prop_list)}) should be equal to number of half the number of undirected edges i.e. must be equal to {num_directed_edges} " + + return prop_list + + def _check_modify_bond_prop_value( + self, bond: Chem.rdchem.Bond | dict, prop: str + ) -> str: + """ + Helper to check and get bond property value. + + Args: + bond (Chem.rdchem.Bond | dict): Bond object or bond property dict. + prop (str): Property key to get. + + Returns: + str: Property value. + + Raises: + ValueError: If value is empty or falsy. + """ + value = self._get_bond_prop_value(bond, prop) + if not value: + # Every atom/node should have given value + raise ValueError(f"'{prop}' is set but empty.") + return value + + @staticmethod + def _get_bond_prop_value(bond: Chem.rdchem.Bond | dict, prop: str) -> str: + """ + Extract bond property value from bond or dict. + + Args: + bond (Chem.rdchem.Bond | dict): Bond object or dict. + prop (str): Property key. + + Returns: + str: Property value. + + Raises: + TypeError: If bond is not the expected type. 
+ """ + if isinstance(bond, Chem.rdchem.Bond): + return bond.GetProp(prop) + elif isinstance(bond, dict): + return bond[prop] + else: + raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py index ccf869c..19e8a4e 100644 --- a/chebai_graph/preprocessing/properties/properties.py +++ b/chebai_graph/preprocessing/properties/properties.py @@ -9,7 +9,7 @@ PropertyEncoder, ) -from .base import AtomProperty, BondProperty, MolecularProperty +from .base import AtomProperty, BondProperty, MoleculeProperty class AtomType(AtomProperty): @@ -242,7 +242,7 @@ def get_bond_value(self, bond: Chem.rdchem.Bond) -> bool: return bond.IsInRing() -class MoleculeNumRings(MolecularProperty): +class MoleculeNumRings(MoleculeProperty): """ Molecule-level property representing the number of rings in the molecule. @@ -265,7 +265,7 @@ def get_property_value(self, mol: Chem.rdchem.Mol) -> list[int]: return [mol.GetRingInfo().NumRings()] -class RDKit2DNormalized(MolecularProperty): +class RDKit2DNormalized(MoleculeProperty): """ Molecule-level property representing normalized 2D descriptors from RDKit. @@ -274,6 +274,7 @@ class RDKit2DNormalized(MolecularProperty): def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or AsIsEncoder(self)) + self.generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() def get_property_value(self, mol: Chem.rdchem.Mol) -> list[np.ndarray]: """ @@ -285,8 +286,7 @@ def get_property_value(self, mol: Chem.rdchem.Mol) -> list[np.ndarray]: Returns: list[np.ndarray]: List containing the descriptor numpy array (excluding first element). 
""" - generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() - features_normalized = generator_normalized.processMol( + features_normalized = self.generator_normalized.processMol( mol, Chem.MolToSmiles(mol) ) features_normalized = np.nan_to_num(features_normalized) From c00b65c3da8a5157f5559291cde24166941ce7a4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 20 Jul 2025 00:40:16 +0200 Subject: [PATCH 182/224] add class for props as per node type --- chebai_graph/preprocessing/datasets/chebi.py | 158 ++++++++++++++++++- 1 file changed, 154 insertions(+), 4 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index bf148d2..c4db4c2 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -15,9 +15,13 @@ from torch_geometric.data.data import Data as GeomData from chebai_graph.preprocessing.properties import ( + AllNodeTypeProperty, + AtomNodeTypeProperty, AtomProperty, BondProperty, + FGNodeTypeProperty, MolecularProperty, + MoleculeProperty, ) from chebai_graph.preprocessing.reader import ( AtomFGReader_NoFGEdges_WithGraphNode, @@ -41,7 +45,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) -class GraphPropertiesMixIn(ChEBIOverX, ABC): +class DataPropertiesSetter(ChEBIOverX, ABC): """Mixin for adding molecular property encodings to graph-based ChEBI datasets.""" READER = GraphPropertyReader @@ -172,6 +176,8 @@ def _after_setup(self, **kwargs) -> None: self._setup_properties() super()._after_setup(**kwargs) + +class GraphPropertiesMixIn(DataPropertiesSetter, ABC): def _merge_props_into_base(self, row: pd.Series) -> GeomData: """ Merge encoded molecular properties into the GeomData object. 
@@ -208,8 +214,10 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData: [edge_attr, torch.cat([property_values, property_values], dim=0)], dim=1, ) - else: + elif isinstance(property, MoleculeProperty): molecule_attr = torch.cat([molecule_attr, property_values], dim=1) + else: + raise TypeError(f"Unsupported property type: {type(property).__name__}") return GeomData( x=x, @@ -261,11 +269,153 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" f"Use n_atom_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, " f"n_bond_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, " - f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if not isinstance(p, (AtomProperty, BondProperty)))}" + f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}" + ) + + return base_df[base_data[0].keys()].to_dict("records") + + +class GraphPropertiesAsPerNodeType(DataPropertiesSetter, ABC): + READER = AtomFGReader_WithFGEdges_WithGraphNode + + def load_processed_data_from_file(self, filename: str) -> list[dict]: + """ + Load dataset and merge cached properties into base features. + + Args: + filename: The path to the file to load. + + Returns: + List of data entries, each a dictionary. 
+ """ + base_data = super().load_processed_data_from_file(filename) + base_df = pd.DataFrame(base_data) + + props_categories = { + "AllNodeTypeProperties": [], + "FGNodeTypeProperties": [], + "AtomNodeTypeProperties": [], + "GraphNodeTypeProperties": [], + "BondProperties": [], + } + n_atom_node_properties, n_fg_node_properties = 0, 0 + n_bond_properties, n_graph_node_properties = 0, 0 + prop_lengths = [] + for prop in self.properties: + prop_length = prop.encoder.get_encoding_length() + prop_name = prop.name + prop_lengths.append((prop_name, prop_length)) + if isinstance(prop, AllNodeTypeProperty): + n_atom_node_properties += prop_length + n_fg_node_properties += prop_length + props_categories["AllNodeTypeProperties"].append(prop_name) + elif isinstance(prop, FGNodeTypeProperty): + n_fg_node_properties += prop_length + props_categories["FGNodeTypeProperties"].append(prop_name) + elif isinstance(prop, AtomNodeTypeProperty): + n_atom_node_properties += prop_length + props_categories["AtomNodeTypeProperties"].append(prop_name) + elif isinstance(prop, BondProperty): + n_bond_properties += prop_length + props_categories["BondProperties"].append(prop_name) + elif isinstance(prop, MoleculeProperty): + # molecule props will be used as graph node props + n_graph_node_properties += prop_length + props_categories["GraphNodeTypeProperties"].append(prop_name) + else: + raise TypeError(f"Unsupported property type: {type(prop).__name__}") + + n_atom_properties = max( + n_atom_node_properties, n_fg_node_properties, n_graph_node_properties ) + rank_zero_info( + f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" + f"Properties Categories {props_categories}\n" + f"n_atom_node_properties: {n_atom_node_properties}, " + f"n_fg_node_properties: {n_fg_node_properties}, " + f"n_bond_properties: {n_bond_properties}, " + f"n_graph_node_properties: {n_graph_node_properties}\n" + f"Use n_atom_properties: {n_atom_properties}, n_bond_properties: 
{n_bond_properties}, n_molecule_properties: 0" + ) + + for property in self.properties: + property_data = torch.load( + self.get_property_path(property), weights_only=False + ) + if len(property_data[0][property.name].shape) > 1: + property.encoder.set_encoding_length( + property_data[0][property.name].shape[1] + ) + + property_df = pd.DataFrame(property_data) + property_df.rename( + columns={property.name: f"{property.name}"}, inplace=True + ) + base_df = base_df.merge(property_df, on="ident", how="left") + + base_df["features"] = base_df.apply( + lambda row: self._merge_props_into_base(row), axis=1 + ) + + # apply transformation, e.g. masking for pretraining task + if self.transform is not None: + base_df["features"] = base_df["features"].apply(self.transform) return base_df[base_data[0].keys()].to_dict("records") + def _merge_props_into_base(self, row: pd.Series) -> GeomData: + """ + Merge encoded molecular properties into the GeomData object. + + Args: + row: A dictionary containing 'features' and encoded properties. + + Returns: + A GeomData object with merged features. 
+ """ + geom_data = row["features"] + assert isinstance(geom_data, GeomData) + is_atom_node = geom_data.is_atom_node + assert is_atom_node is not None, "`is_atom_node` must be set in the geom_data" + is_graph_node = geom_data.is_graph_node + assert is_graph_node is not None, "`is_graph_node` must be set in the geom_data" + + edge_attr = geom_data.edge_attr + x = geom_data.x + molecule_attr = torch.empty((1, 0)) + + for property in self.properties: + property_values = row[f"{property.name}"] + if isinstance(property_values, torch.Tensor): + if len(property_values.size()) == 0: + property_values = property_values.unsqueeze(0) + if len(property_values.size()) == 1: + property_values = property_values.unsqueeze(1) + else: + property_values = torch.zeros( + (0, property.encoder.get_encoding_length()) + ) + + if isinstance(property, AtomProperty): + x = torch.cat([x, property_values], dim=1) + elif isinstance(property, BondProperty): + # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges + edge_attr = torch.cat( + [edge_attr, torch.cat([property_values, property_values], dim=0)], + dim=1, + ) + elif isinstance(property, MoleculeProperty): + molecule_attr = torch.cat([molecule_attr, property_values], dim=1) + else: + raise TypeError(f"Unsupported property type: {type(property).__name__}") + + return GeomData( + x=x, + edge_index=geom_data.edge_index, + edge_attr=edge_attr, + molecule_attr=molecule_attr, + ) + class ChEBI50GraphProperties(GraphPropertiesMixIn, ChEBIOver50): """ChEBIOver50 dataset with molecular property encodings.""" @@ -310,7 +460,7 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData: data = super()._merge_props_into_base(row) return self._add_graph_node_mask(data, row) - def _add_graph_node_mask(self, data: GeomData, row) -> GeomData: + def _add_graph_node_mask(self, data: GeomData, row: pd.Series) -> GeomData: """ Add a graph node mask to the GeomData object. 
From 280f71aa1455570ac3063fe1e5e515626e3a4cf7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 20 Jul 2025 15:36:22 +0200 Subject: [PATCH 183/224] rearrange prop so that allnode prop type to be first --- configs/data/chebi50_aug_all_props.yml | 4 ++-- configs/data/chebi50_aug_prop_as_per_node.yml | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 configs/data/chebi50_aug_prop_as_per_node.yml diff --git a/configs/data/chebi50_aug_all_props.yml b/configs/data/chebi50_aug_all_props.yml index f8bc84f..47e75ad 100644 --- a/configs/data/chebi50_aug_all_props.yml +++ b/configs/data/chebi50_aug_all_props.yml @@ -2,6 +2,7 @@ class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp init_args: properties: # Atom properties + - chebai_graph.preprocessing.properties.AtomNodeLevel - chebai_graph.preprocessing.properties.AugAtomType - chebai_graph.preprocessing.properties.AugNumAtomBonds - chebai_graph.preprocessing.properties.AugAtomCharge @@ -9,12 +10,11 @@ init_args: - chebai_graph.preprocessing.properties.AugAtomHybridization - chebai_graph.preprocessing.properties.AugAtomNumHs - chebai_graph.preprocessing.properties.AtomFunctionalGroup - - chebai_graph.preprocessing.properties.AtomNodeLevel - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG - chebai_graph.preprocessing.properties.IsFGAlkyl # Bond properties + - chebai_graph.preprocessing.properties.BondLevel - chebai_graph.preprocessing.properties.AugBondType - chebai_graph.preprocessing.properties.AugBondInRing - chebai_graph.preprocessing.properties.AugBondAromaticity - - chebai_graph.preprocessing.properties.BondLevel diff --git a/configs/data/chebi50_aug_prop_as_per_node.yml b/configs/data/chebi50_aug_prop_as_per_node.yml new file mode 100644 index 0000000..effb04a --- /dev/null +++ b/configs/data/chebi50_aug_prop_as_per_node.yml @@ -0,0 +1,24 @@ +class_path: 
chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType +init_args: + properties: + # All Node type properties + - chebai_graph.preprocessing.properties.AtomNodeLevel + # Atom Node type properties + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds + - chebai_graph.preprocessing.properties.AugAtomCharge + - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomHybridization + - chebai_graph.preprocessing.properties.AugAtomNumHs + # FG Node type properties + - chebai_graph.preprocessing.properties.AtomFunctionalGroup + - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG + - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG + - chebai_graph.preprocessing.properties.IsFGAlkyl + # Graph Node type properties + - chebai_graph.preprocessing.properties.RDKit2DNormalized + # Bond properties + - chebai_graph.preprocessing.properties.BondLevel + - chebai_graph.preprocessing.properties.AugBondType + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondAromaticity From 6aa18d6b22cb7c5758d8b5152c554fc165197e07 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 20 Jul 2025 15:38:05 +0200 Subject: [PATCH 184/224] add atom and bond level prop tokens --- .../preprocessing/bin/AtomNodeLevel/indices_one_hot.txt | 3 +++ chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt | 4 ++++ 2 files changed, 7 insertions(+) create mode 100644 chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt create mode 100644 chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt diff --git a/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt new file mode 100644 index 0000000..2d65776 --- /dev/null +++ b/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt @@ -0,0 +1,3 @@ +atom_node_lvl 
+fg_node_lvl +graph_node_level diff --git a/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt b/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt new file mode 100644 index 0000000..c5f7ed0 --- /dev/null +++ b/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt @@ -0,0 +1,4 @@ +atom_fg_lvl +fg_graphNode_lvl +within_atoms_lvl +within_fg_lvl From f776af155a01c817346eaa0aab5c9394da261b5d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 20 Jul 2025 23:22:29 +0200 Subject: [PATCH 185/224] restructure resgated params - add pyg impl to compare --- chebai_graph/models/base.py | 17 ++--- chebai_graph/models/resgated.py | 131 +++++++++++++++++++++++++++----- configs/model/resgated.yml | 11 ++- 3 files changed, 126 insertions(+), 33 deletions(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 56eed31..5d5eb6f 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -52,16 +52,17 @@ def __init__(self, config: dict, **kwargs) -> None: - 'hidden_length' - 'dropout_rate' - 'n_conv_layers' - - 'n_atom_properties' + - 'n_node_properties' - 'n_bond_properties' **kwargs: Additional keyword arguments for torch.nn.Module. 
""" super().__init__(**kwargs) - self.hidden_length = int(config["hidden_length"]) - self.dropout_rate = float(config["dropout_rate"]) - self.n_conv_layers = int(config["n_conv_layers"]) - self.n_atom_properties = int(config["n_atom_properties"]) - self.n_bond_properties = int(config["n_bond_properties"]) + self.hidden_channels = int(config["hidden_channels"]) + self.out_channels = int(config["out_channels"]) + self.num_layers = int(config["num_layers"]) + assert self.num_layers > 1, "Need atleast two convolution layers" + self.n_node_properties = int(config["n_node_properties"]) # in_channels + self.n_bond_properties = int(config["n_bond_properties"]) # edge_dim class GraphNetWrapper(GraphBaseNet, ABC): @@ -83,9 +84,7 @@ def __init__( """ super().__init__(**kwargs) self.gnn = self._get_gnn(config) - gnn_out_dim = ( - config["out_dim"] if "out_dim" in config else config["hidden_length"] - ) + gnn_out_dim = int(config["out_channels"]) self.activation = torch.nn.ELU self.lin_input_dim = self._get_lin_seq_input_dim( gnn_out_dim=gnn_out_dim, diff --git a/chebai_graph/models/resgated.py b/chebai_graph/models/resgated.py index 9244327..b07531d 100644 --- a/chebai_graph/models/resgated.py +++ b/chebai_graph/models/resgated.py @@ -1,8 +1,12 @@ +from typing import Final, Tuple, Union + import torch import torch.nn.functional as F -from torch import nn +from torch.nn import ELU from torch_geometric import nn as tgnn from torch_geometric.data import Data as GraphData +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.models.basic_gnn import BasicGNN from .base import GraphModelBase, GraphNetWrapper @@ -24,38 +28,36 @@ def __init__(self, config: dict, **kwargs): Args: config (dict): Configuration dictionary with keys: - - 'in_length' (int): Intermediate feature length used in GNN layers. + - 'hidden_length' (int): Intermediate feature length used in GNN layers. - Other parameters inherited from GraphModelBase. 
**kwargs: Additional keyword arguments passed to GraphModelBase. """ super().__init__(config=config, **kwargs) - self.in_length = int(config["in_length"]) self.activation = F.elu - self.dropout = nn.Dropout(self.dropout_rate) - self.convs = torch.nn.ModuleList() - for i in range(self.n_conv_layers): - if i == 0: - # Initial layer uses atom features as input - self.convs.append( - tgnn.ResGatedGraphConv( - self.n_atom_properties, - self.in_length, - # dropout=self.dropout_rate, - edge_dim=self.n_bond_properties, - ) - ) + self.convs.append( + tgnn.ResGatedGraphConv( + self.n_node_properties, + self.hidden_channels, + # dropout=self.dropout, + edge_dim=self.n_bond_properties, + ) + ) + + for _ in range(self.num_layers - 2): # Intermediate layers self.convs.append( tgnn.ResGatedGraphConv( - self.in_length, self.in_length, edge_dim=self.n_bond_properties + self.hidden_channels, + self.hidden_channels, + edge_dim=self.n_bond_properties, ) ) # Final projection layer to hidden dimension self.final_conv = tgnn.ResGatedGraphConv( - self.in_length, self.hidden_length, edge_dim=self.n_bond_properties + self.hidden_channels, self.out_channels, edge_dim=self.n_bond_properties ) def forward(self, batch: dict) -> torch.Tensor: @@ -109,3 +111,96 @@ def _get_gnn(self, config: dict) -> ResGatedGraphConvNetBase: ResGatedGraphConvNetBase: The GNN encoder. """ return ResGatedGraphConvNetBase(config=config) + + +class ResGatedModel(BasicGNN): + supports_edge_weight: Final[bool] = False + supports_edge_attr: Final[bool] = True + supports_norm_batch: Final[bool] + + def init_conv( + self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs + ) -> MessagePassing: + return tgnn.ResGatedGraphConv( + in_channels, + out_channels, + **kwargs, + ) + + +class ResGatedPyG(GraphModelBase): + """ + Graph Attention Network (GAT) base module for graph convolution. 
+ + Uses PyTorch Geometric's `GAT` implementation to process atomic node features + and bond edge attributes through multiple attention heads and layers. + """ + + def __init__(self, config: dict, **kwargs): + """ + Initialize the GATGraphConvNetBase. + + Args: + config (dict): Model configuration containing: + - 'heads' (int): Number of attention heads. + - 'v2' (bool): Whether to use the GATv2 variant. + - Other required GraphModelBase parameters. + **kwargs: Additional arguments for the base class. + """ + super().__init__(config=config, **kwargs) + self.activation = ELU() # Instantiate ELU once for reuse. + self.gat = ResGatedModel( + in_channels=self.n_node_properties, + hidden_channels=self.hidden_channels, + out_channels=self.out_channels, + num_layers=self.num_layers, + edge_dim=self.n_bond_properties, + act=self.activation, + ) + + def forward(self, batch: dict) -> torch.Tensor: + """ + Forward pass through the GAT network. + + Processes atomic node features and edge attributes, and applies + an ELU activation to the output. + + Args: + batch (dict): Input batch containing: + - 'features': A list with a `GraphData` object as its first element. + + Returns: + torch.Tensor: Node embeddings after GAT and activation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData) + + out = self.gat( + x=graph_data.x.float(), + edge_index=graph_data.edge_index.long(), + edge_attr=graph_data.edge_attr, + ) + + return self.activation(out) + + +class ResGatedGraphPredPyG(GraphNetWrapper): + """ + Residual Gated GNN for Graph Prediction. + + Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings. + """ + + NAME = "ResGatedGraphPred" + + def _get_gnn(self, config: dict) -> ResGatedPyG: + """ + Instantiate the residual gated GNN backbone. + + Args: + config (dict): Model configuration. + + Returns: + ResGatedGraphConvNetBase: The GNN encoder. 
+ """ + return ResGatedPyG(config=config) diff --git a/configs/model/resgated.yml b/configs/model/resgated.yml index 581e742..6c6e4aa 100644 --- a/configs/model/resgated.yml +++ b/configs/model/resgated.yml @@ -3,11 +3,10 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - in_length: 256 - hidden_length: 512 - dropout_rate: 0 - n_conv_layers: 3 - n_atom_properties: 158 - n_bond_properties: 7 + n_node_properties: 68 # in_channels + hidden_channels : 256 + out_channels : 512 + num_layers : 4 + n_bond_properties: 4 # edge_dim n_molecule_properties: 0 n_linear_layers: 2 From 54b0d070ce856437533ba8169619223b0c6da1cc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 21 Jul 2025 16:40:33 +0200 Subject: [PATCH 186/224] add test for grap prop as per node class --- .../dataclasses/testGraphPropAsPerNodeType.py | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 tests/unit/dataclasses/testGraphPropAsPerNodeType.py diff --git a/tests/unit/dataclasses/testGraphPropAsPerNodeType.py b/tests/unit/dataclasses/testGraphPropAsPerNodeType.py new file mode 100644 index 0000000..df9a52f --- /dev/null +++ b/tests/unit/dataclasses/testGraphPropAsPerNodeType.py @@ -0,0 +1,151 @@ +import unittest + +import pandas as pd +import torch +from torch_geometric.data.data import Data as GeomData + +from chebai_graph.preprocessing.datasets.chebi import ChEBI50_WFGE_WGN_AsPerNodeType +from chebai_graph.preprocessing.properties import ( + AtomNodeLevel, + AugAtomCharge, + AugAtomHybridization, + BondLevel, + IsFGAlkyl, + IsHydrogenBondAcceptorFG, + RDKit2DNormalized, +) + + +class TestGraphPropAsPerNodeType(unittest.TestCase): + def test_merge_properties(self): + num_nodes = 4 + dummy_x = torch.zeros((num_nodes, 0)) # Initial dummy x + dummy_edge_index = torch.tensor([[0, 1], [1, 0]]) + dummy_edge_attr = torch.zeros((4, 0)) # 4 edges, each with 0 feature + + # Masks + is_atom_node = torch.tensor([1, 0, 1, 0], dtype=torch.bool) + is_graph_node = torch.tensor([0, 0, 0, 1], 
dtype=torch.bool) + + # GeomData + geom_data = GeomData( + x=dummy_x, + edge_index=dummy_edge_index, + edge_attr=dummy_edge_attr, + is_atom_node=is_atom_node, + is_graph_node=is_graph_node, + ) + + # Define properties + # atom props = 5, fg_props = 4, graph_node_props = 6; max = 6 + all_node_prop = AtomNodeLevel(DummyEncoder(2)) + atom_prop = AugAtomCharge(DummyEncoder(1)) + atom_prop_2 = AugAtomHybridization(DummyEncoder(2)) + fg_prop = IsFGAlkyl(DummyEncoder(1)) + fg_prop_2 = IsHydrogenBondAcceptorFG(DummyEncoder(1)) + mol_prop = RDKit2DNormalized(DummyEncoder(4)) + bond_prop = BondLevel(DummyEncoder(2)) + + properties = [ + atom_prop, + atom_prop_2, + fg_prop, + fg_prop_2, + bond_prop, + all_node_prop, + mol_prop, + ] + + merger = ChEBI50_WFGE_WGN_AsPerNodeType(properties) + + # Define encoded property values for the row + row = pd.Series( + { + "features": geom_data, + "AtomNodeLevel": torch.tensor( + [ + [1.0, 0.0], # atom + [0.0, 1.0], # fg + [0.0, 0.0], # atom + [0.0, 1.0], # graph + ] + ), + "AtomCharge": torch.tensor( + [ + [1.0], # atom + [2.0], # fg + [6.0], # atom + [3.0], # graph + ] + ), + "AtomHybridization": torch.tensor( + [ + [11.0, 9.0], # atom + [7.0, 3.0], # fg + [3.0, 1.0], # atom + [7.0, 43.0], # graph + ] + ), + "IsFGAlkyl": torch.tensor( + [ + [5.0], # atom + [55.0], # fg + [13.0], # atom + [14.0], # graph + ] + ), # values for fg at idx 1 + "IsHydrogenBondAcceptorFG": torch.tensor( + [ + [3.0], # atom + [5.0], # fg + [17.0], # atom + [15.0], # graph + ] + ), + "RDKit2DNormalized": torch.tensor( + [ + [65.0, 23.0, 6.0, 8.0], # atom + [2.0, 8.0, 55.0, 77.0], # fg + [3.0, 51.0, 55.0, 3.0], # atom + [33.0, 6.0, 10.0, 10.0], # graph + ] + ), # only idx 3 + "BondLevel": torch.tensor( + [ + [0.1, 0.2], + [0.3, 0.4], + ] + ), # will be duplicated to 4x2 + } + ) + expected_result = torch.tensor( + [ + # all node # ap1 # ap2 # 0 concat + [1.0, 0.0] + [1.0] + [11.0, 9.0] + [0.0], # atom node + # all node # fg1 # fg2 # 0 concat + [0.0, 1.0] + [55.0] + 
[5.0] + [0.0, 0.0], # fg node + # all node # ap1 # ap2 # 0 concat + [0.0, 0.0] + [6.0] + [3.0, 1.0] + [0.0], # atom node + # all node # mol props + [0.0, 1.0] + [33.0, 6.0, 10.0, 10.0], # graph node + ] + ) + + result = merger._merge_props_into_base(row, max_len_node_properties=6) + self.assertTrue(torch.equal(result.x, expected_result)) + + +class DummyEncoder: + def __init__(self, length): + self.length = length + + @property + def name(self): + return self.__class__.__name__.replace("DummyEncoder", "") + + def get_encoding_length(self): + return self.length + + +if __name__ == "__main__": + unittest.main() From e4e5397c143a2788e303a02a2c18f9b893a9da67 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 21 Jul 2025 16:43:07 +0200 Subject: [PATCH 187/224] modify merge method to have diff props for diff nodes --- chebai_graph/preprocessing/datasets/chebi.py | 81 +++++++++++++++++--- 1 file changed, 70 insertions(+), 11 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index c4db4c2..17a81b1 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -275,8 +275,19 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: return base_df[base_data[0].keys()].to_dict("records") -class GraphPropertiesAsPerNodeType(DataPropertiesSetter, ABC): - READER = AtomFGReader_WithFGEdges_WithGraphNode +class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): + def __init__(self, properties=None, transform=None, **kwargs): + super().__init__(properties, transform, **kwargs) + # Sort properties so that AllNodeTypeProperty instances come first, rest of the properties order remain same + first = [ + prop for prop in self.properties if isinstance(prop, AllNodeTypeProperty) + ] + rest = [ + prop + for prop in self.properties + if not isinstance(prop, AllNodeTypeProperty) + ] + self.properties = first + rest def load_processed_data_from_file(self, 
filename: str) -> list[dict]: """ @@ -308,6 +319,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: if isinstance(prop, AllNodeTypeProperty): n_atom_node_properties += prop_length n_fg_node_properties += prop_length + n_graph_node_properties += prop_length props_categories["AllNodeTypeProperties"].append(prop_name) elif isinstance(prop, FGNodeTypeProperty): n_fg_node_properties += prop_length @@ -354,7 +366,11 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: base_df = base_df.merge(property_df, on="ident", how="left") base_df["features"] = base_df.apply( - lambda row: self._merge_props_into_base(row), axis=1 + lambda row: self._merge_props_into_base( + row, + max_len_node_properties=n_atom_properties, + ), + axis=1, ) # apply transformation, e.g. masking for pretraining task @@ -363,7 +379,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: return base_df[base_data[0].keys()].to_dict("records") - def _merge_props_into_base(self, row: pd.Series) -> GeomData: + def _merge_props_into_base( + self, row: pd.Series, max_len_node_properties: int + ) -> GeomData: """ Merge encoded molecular properties into the GeomData object. 
@@ -375,14 +393,24 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData: """ geom_data = row["features"] assert isinstance(geom_data, GeomData) + is_atom_node = geom_data.is_atom_node assert is_atom_node is not None, "`is_atom_node` must be set in the geom_data" is_graph_node = geom_data.is_graph_node assert is_graph_node is not None, "`is_graph_node` must be set in the geom_data" + is_fg_node = ~is_atom_node & ~is_graph_node + num_nodes = geom_data.x.size(0) edge_attr = geom_data.edge_attr - x = geom_data.x - molecule_attr = torch.empty((1, 0)) + + # Initialize node feature matrix + assert ( + max_len_node_properties is not None + ), "Maximum len of node properties should not be None" + x = torch.zeros((num_nodes, max_len_node_properties)) + + # Track column offsets for each node type + atom_offset, fg_offset, graph_offset = 0, 0, 0 for property in self.properties: property_values = row[f"{property.name}"] @@ -396,24 +424,51 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData: (0, property.encoder.get_encoding_length()) ) - if isinstance(property, AtomProperty): - x = torch.cat([x, property_values], dim=1) + enc_len = property_values.shape[1] + # -------------- Node properties --------------- + if isinstance(property, AllNodeTypeProperty): + x[:, atom_offset : atom_offset + enc_len] = property_values + atom_offset += enc_len + fg_offset += enc_len + graph_offset += enc_len + + elif isinstance(property, AtomNodeTypeProperty): + x[is_atom_node, atom_offset : atom_offset + enc_len] = property_values[ + is_atom_node + ] + atom_offset += enc_len + + elif isinstance(property, FGNodeTypeProperty): + x[is_fg_node, fg_offset : fg_offset + enc_len] = property_values[ + is_fg_node + ] + fg_offset += enc_len + + elif isinstance(property, MoleculeProperty): + x[is_graph_node, graph_offset : graph_offset + enc_len] = ( + property_values[is_graph_node] + ) + graph_offset += enc_len + + # ------------- Bond Properties -------------- elif 
isinstance(property, BondProperty): # Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges edge_attr = torch.cat( [edge_attr, torch.cat([property_values, property_values], dim=0)], dim=1, ) - elif isinstance(property, MoleculeProperty): - molecule_attr = torch.cat([molecule_attr, property_values], dim=1) else: raise TypeError(f"Unsupported property type: {type(property).__name__}") + total_used_columns = max(atom_offset, fg_offset, graph_offset) + assert ( + total_used_columns <= max_len_node_properties + ), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}" + return GeomData( x=x, edge_index=geom_data.edge_index, edge_attr=edge_attr, - molecule_attr=molecule_attr, ) @@ -507,3 +562,7 @@ class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver5 """ChEBIOver50 with atom-level nodes and graph node only.""" READER = AtomReader_WithGraphNodeOnly + + +class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50): + READER = AtomFGReader_WithFGEdges_WithGraphNode From dfdd8108a85fae069deb542133f14372a55301e3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Jul 2025 23:48:17 +0200 Subject: [PATCH 188/224] use pyg model impl without changing the architecture - https://github.com/ChEB-AI/python-chebai-graph/issues/12 --- chebai_graph/models/base.py | 20 +- chebai_graph/models/gat.py | 11 +- chebai_graph/models/resgated.py | 183 +++++-------------- chebai_graph/preprocessing/datasets/chebi.py | 12 +- configs/model/gat.yml | 16 +- configs/model/resgated.yml | 11 +- 6 files changed, 81 insertions(+), 172 deletions(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 5d5eb6f..1ddc784 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -49,20 +49,22 @@ def __init__(self, config: dict, **kwargs) -> None: Args: config (dict): Configuration dictionary with keys: - - 'hidden_length' - - 
'dropout_rate' - - 'n_conv_layers' - - 'n_node_properties' - - 'n_bond_properties' + - 'num_layers' + - 'in_channels' + - 'hidden_channels' + - 'out_channels' + - 'edge_dim' + - 'dropout' **kwargs: Additional keyword arguments for torch.nn.Module. """ super().__init__(**kwargs) - self.hidden_channels = int(config["hidden_channels"]) - self.out_channels = int(config["out_channels"]) self.num_layers = int(config["num_layers"]) assert self.num_layers > 1, "Need atleast two convolution layers" - self.n_node_properties = int(config["n_node_properties"]) # in_channels - self.n_bond_properties = int(config["n_bond_properties"]) # edge_dim + self.in_channels = int(config["in_channels"]) # number of node/atom properties + self.hidden_channels = int(config["hidden_channels"]) + self.out_channels = int(config["out_channels"]) + self.edge_dim = int(config["edge_dim"]) # number of bond properties + self.dropout = float(config["dropout"]) class GraphNetWrapper(GraphBaseNet, ABC): diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py index c289e9e..230d3c5 100644 --- a/chebai_graph/models/gat.py +++ b/chebai_graph/models/gat.py @@ -30,11 +30,12 @@ def __init__(self, config: dict, **kwargs): self.v2 = bool(config["v2"]) self.activation = ELU() # Instantiate ELU once for reuse. 
self.gat = GAT( - in_channels=self.n_atom_properties, - hidden_channels=self.hidden_length, - num_layers=self.n_conv_layers, - dropout=self.dropout_rate, - edge_dim=self.n_bond_properties, + in_channels=self.in_channels, + hidden_channels=self.hidden_channels, + out_channels=self.out_channels, + num_layers=self.num_layers, + dropout=self.dropout, + edge_dim=self.edge_dim, heads=self.heads, v2=self.v2, act=self.activation, diff --git a/chebai_graph/models/resgated.py b/chebai_graph/models/resgated.py index b07531d..521a853 100644 --- a/chebai_graph/models/resgated.py +++ b/chebai_graph/models/resgated.py @@ -1,7 +1,6 @@ -from typing import Final, Tuple, Union +from typing import Any, Final -import torch -import torch.nn.functional as F +from torch import Tensor from torch.nn import ELU from torch_geometric import nn as tgnn from torch_geometric.data import Data as GraphData @@ -11,116 +10,34 @@ from .base import GraphModelBase, GraphNetWrapper -class ResGatedGraphConvNetBase(GraphModelBase): - """ - Residual Gated Graph Convolutional Network with edge attributes support. - - This model uses a stack of `ResGatedGraphConv` layers from PyTorch Geometric, - allowing edge attributes as part of message passing. A final projection layer maps - to the hidden length specified for downstream graph prediction tasks. - """ - - NAME = "ResGatedGraphConvNetBase" - - def __init__(self, config: dict, **kwargs): - """ - Initialize the ResGatedGraphConvNetBase. - - Args: - config (dict): Configuration dictionary with keys: - - 'hidden_length' (int): Intermediate feature length used in GNN layers. - - Other parameters inherited from GraphModelBase. - **kwargs: Additional keyword arguments passed to GraphModelBase. 
- """ - super().__init__(config=config, **kwargs) - - self.activation = F.elu - self.convs = torch.nn.ModuleList() - self.convs.append( - tgnn.ResGatedGraphConv( - self.n_node_properties, - self.hidden_channels, - # dropout=self.dropout, - edge_dim=self.n_bond_properties, - ) - ) - - for _ in range(self.num_layers - 2): - # Intermediate layers - self.convs.append( - tgnn.ResGatedGraphConv( - self.hidden_channels, - self.hidden_channels, - edge_dim=self.n_bond_properties, - ) - ) - - # Final projection layer to hidden dimension - self.final_conv = tgnn.ResGatedGraphConv( - self.hidden_channels, self.out_channels, edge_dim=self.n_bond_properties - ) - - def forward(self, batch: dict) -> torch.Tensor: - """ - Forward pass through residual gated GNN layers. - - Args: - batch (dict): A batch containing: - - 'features': A list with a `GraphData` instance as the first element. - - Returns: - torch.Tensor: Node-level embeddings of shape [num_nodes, hidden_length]. - """ - graph_data = batch["features"][0] - assert isinstance(graph_data, GraphData) - - x = graph_data.x.float() # Atom features - - for conv in self.convs: - assert isinstance(conv, tgnn.ResGatedGraphConv) - x = self.activation( - conv(x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr) - ) - - x = self.activation( - self.final_conv( - x, graph_data.edge_index.long(), edge_attr=graph_data.edge_attr - ) - ) - - return x - - -class ResGatedGraphPred(GraphNetWrapper): +class ResGatedModel(BasicGNN): """ - Residual Gated GNN for Graph Prediction. + A residual gated GNN model based on PyG's BasicGNN using ResGatedGraphConv layers. - Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings. + Attributes: + supports_edge_weight (bool): Indicates edge weights are not supported. + supports_edge_attr (bool): Indicates edge attributes are supported. + supports_norm_batch (bool): Indicates if batch normalization is supported. 
""" - NAME = "ResGatedGraphPred" + supports_edge_weight: Final[bool] = False + supports_edge_attr: Final[bool] = True + supports_norm_batch: Final[bool] - def _get_gnn(self, config: dict) -> ResGatedGraphConvNetBase: + def init_conv( + self, in_channels: int | tuple[int, int], out_channels: int, **kwargs: Any + ) -> MessagePassing: """ - Instantiate the residual gated GNN backbone. + Initializes a ResGatedGraphConv layer. Args: - config (dict): Model configuration. + in_channels (int or Tuple[int, int]): Number of input channels. + out_channels (int): Number of output channels. + **kwargs: Additional keyword arguments for the convolution layer. Returns: - ResGatedGraphConvNetBase: The GNN encoder. + MessagePassing: A ResGatedGraphConv layer instance. """ - return ResGatedGraphConvNetBase(config=config) - - -class ResGatedModel(BasicGNN): - supports_edge_weight: Final[bool] = False - supports_edge_attr: Final[bool] = True - supports_norm_batch: Final[bool] - - def init_conv( - self, in_channels: Union[int, Tuple[int, int]], out_channels: int, **kwargs - ) -> MessagePassing: return tgnn.ResGatedGraphConv( in_channels, out_channels, @@ -128,54 +45,42 @@ def init_conv( ) -class ResGatedPyG(GraphModelBase): +class ResGatedGraphConvNetBase(GraphModelBase): """ - Graph Attention Network (GAT) base module for graph convolution. + Base model class for applying ResGatedGraphConv layers to graph-structured data. - Uses PyTorch Geometric's `GAT` implementation to process atomic node features - and bond edge attributes through multiple attention heads and layers. + Args: + config (dict): Configuration dictionary containing model hyperparameters. + **kwargs: Additional keyword arguments for parent class. """ - def __init__(self, config: dict, **kwargs): - """ - Initialize the GATGraphConvNetBase. - - Args: - config (dict): Model configuration containing: - - 'heads' (int): Number of attention heads. - - 'v2' (bool): Whether to use the GATv2 variant. 
- - Other required GraphModelBase parameters. - **kwargs: Additional arguments for the base class. - """ + def __init__(self, config: dict[str, Any], **kwargs: Any): super().__init__(config=config, **kwargs) self.activation = ELU() # Instantiate ELU once for reuse. - self.gat = ResGatedModel( - in_channels=self.n_node_properties, + + self.resgated: BasicGNN = ResGatedModel( + in_channels=self.in_channels, hidden_channels=self.hidden_channels, out_channels=self.out_channels, num_layers=self.num_layers, - edge_dim=self.n_bond_properties, + edge_dim=self.edge_dim, act=self.activation, ) - def forward(self, batch: dict) -> torch.Tensor: + def forward(self, batch: dict[str, Any]) -> Tensor: """ - Forward pass through the GAT network. - - Processes atomic node features and edge attributes, and applies - an ELU activation to the output. + Forward pass of the model. Args: - batch (dict): Input batch containing: - - 'features': A list with a `GraphData` object as its first element. + batch (dict): A batch containing graph input features under the key "features". Returns: - torch.Tensor: Node embeddings after GAT and activation. + Tensor: The output node-level embeddings after the final activation. """ graph_data = batch["features"][0] - assert isinstance(graph_data, GraphData) + assert isinstance(graph_data, GraphData), "Expected GraphData instance" - out = self.gat( + out = self.resgated( x=graph_data.x.float(), edge_index=graph_data.edge_index.long(), edge_attr=graph_data.edge_attr, @@ -184,23 +89,21 @@ def forward(self, batch: dict) -> torch.Tensor: return self.activation(out) -class ResGatedGraphPredPyG(GraphNetWrapper): +class ResGatedGraphPred(GraphNetWrapper): """ - Residual Gated GNN for Graph Prediction. + Wrapper for graph-level prediction using ResGatedGraphConvNetBase. - Uses `ResGatedGraphConvNetBase` as the GNN encoder to compute node embeddings. + This class instantiates the core GNN model using the provided config. 
""" - NAME = "ResGatedGraphPred" - - def _get_gnn(self, config: dict) -> ResGatedPyG: + def _get_gnn(self, config: dict[str, Any]) -> ResGatedGraphConvNetBase: """ - Instantiate the residual gated GNN backbone. + Returns the core ResGated GNN model. Args: - config (dict): Model configuration. + config (dict): Configuration dictionary for the GNN model. Returns: - ResGatedGraphConvNetBase: The GNN encoder. + ResGatedGraphConvNetBase: The core graph convolutional network. """ - return ResGatedPyG(config=config) + return ResGatedGraphConvNetBase(config=config) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 17a81b1..d2b788c 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -267,8 +267,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: ] rank_zero_info( f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" - f"Use n_atom_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, " - f"n_bond_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, " + f"Use following values for given parameters for model configuration: \n\t" + f"in_channels: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, " + f"edge_dim: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, " f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}" ) @@ -337,7 +338,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: else: raise TypeError(f"Unsupported property type: {type(prop).__name__}") - n_atom_properties = max( + n_node_properties = max( n_atom_node_properties, n_fg_node_properties, n_graph_node_properties ) rank_zero_info( @@ -347,7 +348,8 @@ def 
load_processed_data_from_file(self, filename: str) -> list[dict]: f"n_fg_node_properties: {n_fg_node_properties}, " f"n_bond_properties: {n_bond_properties}, " f"n_graph_node_properties: {n_graph_node_properties}\n" - f"Use n_atom_properties: {n_atom_properties}, n_bond_properties: {n_bond_properties}, n_molecule_properties: 0" + f"Use following values for given parameters for model configuration: \n\t" + f"in_channels: {n_node_properties}, edge_dim: {n_bond_properties}, n_molecule_properties: 0" ) for property in self.properties: @@ -368,7 +370,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: base_df["features"] = base_df.apply( lambda row: self._merge_props_into_base( row, - max_len_node_properties=n_atom_properties, + max_len_node_properties=n_node_properties, ), axis=1, ) diff --git a/configs/model/gat.yml b/configs/model/gat.yml index 688edb2..a8cfbe0 100644 --- a/configs/model/gat.yml +++ b/configs/model/gat.yml @@ -3,13 +3,13 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - hidden_length: 512 - dropout_rate: 0 - n_conv_layers: 3 + in_channels: 158 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 5 + edge_dim: 7 # number of bond properties heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv - n_atom_properties: 158 - n_bond_properties: 7 - - n_molecule_properties: 200 - n_linear_layers: 3 + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 2 diff --git a/configs/model/resgated.yml b/configs/model/resgated.yml index 6c6e4aa..244f471 100644 --- a/configs/model/resgated.yml +++ b/configs/model/resgated.yml @@ -3,10 +3,11 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - n_node_properties: 68 # in_channels - hidden_channels : 256 - out_channels : 512 - num_layers : 4 - n_bond_properties: 4 # edge_dim + in_channels: 158 # 
number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 5 + edge_dim: 7 # number of bond properties + dropout: 0 n_molecule_properties: 0 n_linear_layers: 2 From ca3f3050a7abcb7de7ec9da2b1a8200032febc06 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Jul 2025 11:45:08 +0200 Subject: [PATCH 189/224] zero padding for atom properties --- chebai_graph/preprocessing/datasets/chebi.py | 23 +++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index d2b788c..03fab5e 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -178,6 +178,17 @@ def _after_setup(self, **kwargs) -> None: class GraphPropertiesMixIn(DataPropertiesSetter, ABC): + def __init__( + self, properties=None, transform=None, zero_pad_atom: int = None, **kwargs + ): + super().__init__(properties, transform, **kwargs) + self.zero_pad_atom = int(zero_pad_atom) if zero_pad_atom is not None else None + if self.zero_pad_atom: + print( + f"[Info] Atom-level features will be zero-padded with " + f"{self.zero_pad_atom} additional dimensions." + ) + def _merge_props_into_base(self, row: pd.Series) -> GeomData: """ Merge encoded molecular properties into the GeomData object. 
@@ -219,6 +230,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData: else: raise TypeError(f"Unsupported property type: {type(property).__name__}") + if self.zero_pad_atom is not None: + x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_atom))], dim=1) + return GeomData( x=x, edge_index=geom_data.edge_index, @@ -265,10 +279,17 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: prop_lengths = [ (prop.name, prop.encoder.get_encoding_length()) for prop in self.properties ] + n_node_properties = sum( + p.encoder.get_encoding_length() + for p in self.properties + if isinstance(p, AtomProperty) + ) + if self.zero_pad_atom: + n_node_properties += self.zero_pad_atom rank_zero_info( f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" f"Use following values for given parameters for model configuration: \n\t" - f"in_channels: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, " + f"in_channels: {n_node_properties} (with {self.zero_pad_atom} padded zeros) , " f"edge_dim: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, " f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}" ) From f5d3cb81e11de4163244da5c7a7815cee5095208 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Jul 2025 12:31:49 +0200 Subject: [PATCH 190/224] sort as per node prop to same seq as baseline for fair comparison --- chebai_graph/preprocessing/datasets/chebi.py | 10 +++++++--- configs/data/chebi50_aug_prop_as_per_node.yml | 10 +++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 03fab5e..9afb74f 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -76,9 +76,6 @@ def __init__( assert 
isinstance(self.properties, list) and all( isinstance(p, MolecularProperty) for p in self.properties ) - rank_zero_info( - f"Data module uses these properties (ordered): {', '.join([str(p) for p in properties])}" - ) self.transform = transform def _setup_properties(self) -> None: @@ -188,6 +185,9 @@ def __init__( f"[Info] Atom-level features will be zero-padded with " f"{self.zero_pad_atom} additional dimensions." ) + print( + f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}" + ) def _merge_props_into_base(self, row: pd.Series) -> GeomData: """ @@ -310,6 +310,10 @@ def __init__(self, properties=None, transform=None, **kwargs): if not isinstance(prop, AllNodeTypeProperty) ] self.properties = first + rest + print( + "Properties are sorted so that `AllNodeTypeProperty` properties are first in sequence and rest of the order remains same\n", + f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}", + ) def load_processed_data_from_file(self, filename: str) -> list[dict]: """ diff --git a/configs/data/chebi50_aug_prop_as_per_node.yml b/configs/data/chebi50_aug_prop_as_per_node.yml index effb04a..dcd77be 100644 --- a/configs/data/chebi50_aug_prop_as_per_node.yml +++ b/configs/data/chebi50_aug_prop_as_per_node.yml @@ -4,12 +4,12 @@ init_args: # All Node type properties - chebai_graph.preprocessing.properties.AtomNodeLevel # Atom Node type properties - - chebai_graph.preprocessing.properties.AugAtomType - - chebai_graph.preprocessing.properties.AugNumAtomBonds - - chebai_graph.preprocessing.properties.AugAtomCharge - chebai_graph.preprocessing.properties.AugAtomAromaticity + - chebai_graph.preprocessing.properties.AugAtomCharge - chebai_graph.preprocessing.properties.AugAtomHybridization - chebai_graph.preprocessing.properties.AugAtomNumHs + - chebai_graph.preprocessing.properties.AugAtomType + - chebai_graph.preprocessing.properties.AugNumAtomBonds # FG Node type properties - 
chebai_graph.preprocessing.properties.AtomFunctionalGroup - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG @@ -19,6 +19,6 @@ init_args: - chebai_graph.preprocessing.properties.RDKit2DNormalized # Bond properties - chebai_graph.preprocessing.properties.BondLevel - - chebai_graph.preprocessing.properties.AugBondType - - chebai_graph.preprocessing.properties.AugBondInRing - chebai_graph.preprocessing.properties.AugBondAromaticity + - chebai_graph.preprocessing.properties.AugBondInRing + - chebai_graph.preprocessing.properties.AugBondType From e71dfa2346ce74279c86e93d2c0742bc16de5f24 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 26 Jul 2025 23:30:25 +0200 Subject: [PATCH 191/224] add string-based sorting for `GraphPropAsPerNodeType` too https://github.com/ChEB-AI/python-chebai-graph/issues/19 --- chebai_graph/preprocessing/datasets/chebi.py | 27 ++++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 9afb74f..5ab1253 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -67,9 +67,7 @@ def __init__( # atom_properties and bond_properties are given as lists containing class_paths if properties is not None: properties = [resolve_property(prop) for prop in properties] - properties = sorted( - properties, key=lambda prop: self.get_property_path(prop) - ) + properties = self._sort_properties(properties) else: properties = [] self.properties = properties @@ -78,6 +76,11 @@ def __init__( ) self.transform = transform + def _sort_properties( + self, properties: list[MolecularProperty] + ) -> list[MolecularProperty]: + return sorted(properties, key=lambda prop: self.get_property_path(prop)) + def _setup_properties(self) -> None: """ Process and cache molecular properties to disk. 
@@ -301,14 +304,16 @@ class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): def __init__(self, properties=None, transform=None, **kwargs): super().__init__(properties, transform, **kwargs) # Sort properties so that AllNodeTypeProperty instances come first, rest of the properties order remain same - first = [ - prop for prop in self.properties if isinstance(prop, AllNodeTypeProperty) - ] - rest = [ - prop - for prop in self.properties - if not isinstance(prop, AllNodeTypeProperty) - ] + first = self._sort_properties( + [prop for prop in self.properties if isinstance(prop, AllNodeTypeProperty)] + ) + rest = self._sort_properties( + [ + prop + for prop in self.properties + if not isinstance(prop, AllNodeTypeProperty) + ] + ) self.properties = first + rest print( "Properties are sorted so that `AllNodeTypeProperty` properties are first in sequence and rest of the order remains same\n", From e30e3acbeaee6dfef470e256050f4b03c85527fc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 27 Jul 2025 12:45:39 +0200 Subject: [PATCH 192/224] static gni --- .../preprocessing/datasets/__init__.py | 2 + chebai_graph/preprocessing/datasets/chebi.py | 27 ++++++- chebai_graph/preprocessing/reader/__init__.py | 2 + .../preprocessing/reader/static_gni.py | 75 +++++++++++++++++++ configs/data/chebi50_static_gni.yml | 7 ++ 5 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 chebai_graph/preprocessing/reader/static_gni.py create mode 100644 configs/data/chebi50_static_gni.yml diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py index 5f8ae1e..01c14fd 100644 --- a/chebai_graph/preprocessing/datasets/__init__.py +++ b/chebai_graph/preprocessing/datasets/__init__.py @@ -2,6 +2,7 @@ ChEBI50_Atom_WGNOnly_GraphProp, ChEBI50_NFGE_NGN_GraphProp, ChEBI50_NFGE_WGN_GraphProp, + ChEBI50_StaticGNI, ChEBI50_WFGE_NGN_GraphProp, ChEBI50_WFGE_WGN_GraphProp, ChEBI50GraphData, @@ -19,4 +20,5 @@ "ChEBI50_NFGE_WGN_GraphProp", 
"ChEBI50_WFGE_NGN_GraphProp", "ChEBI50_WFGE_WGN_GraphProp", + "ChEBI50_StaticGNI", ] diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 5ab1253..e278b59 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -31,6 +31,7 @@ AtomsFGReader_NoFGEdges_NoGraphNode, GraphPropertyReader, GraphReader, + RandomNodeInitializationReader, ) from .utils import resolve_property @@ -188,9 +189,11 @@ def __init__( f"[Info] Atom-level features will be zero-padded with " f"{self.zero_pad_atom} additional dimensions." ) - print( - f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}" - ) + + if self.properties: + print( + f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}" + ) def _merge_props_into_base(self, row: pd.Series) -> GeomData: """ @@ -504,6 +507,24 @@ def _merge_props_into_base( ) +class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50): + READER = RandomNodeInitializationReader + + def _setup_properties(self): ... 
+ + def load_processed_data_from_file(self, filename): + base_data = super().load_processed_data_from_file(filename) + base_df = pd.DataFrame(base_data) + + rank_zero_info( + f"Use following values for given parameters for model configuration: \n\t" + f"in_channels: {self.reader.num_node_properties} , " + f"edge_dim: {self.reader.num_bond_properties}, " + f"n_molecule_properties: {self.reader.num_molecule_properties}" + ) + return base_df[base_data[0].keys()].to_dict("records") + + class ChEBI50GraphProperties(GraphPropertiesMixIn, ChEBIOver50): """ChEBIOver50 dataset with molecular property encodings.""" diff --git a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py index 12df70b..9e58e32 100644 --- a/chebai_graph/preprocessing/reader/__init__.py +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -6,6 +6,7 @@ AtomsFGReader_NoFGEdges_NoGraphNode, ) from .reader import GraphPropertyReader, GraphReader +from .static_gni import RandomNodeInitializationReader __all__ = [ "GraphReader", @@ -15,4 +16,5 @@ "AtomFGReader_NoFGEdges_WithGraphNode", "AtomFGReader_WithFGEdges_NoGraphNode", "AtomFGReader_WithFGEdges_WithGraphNode", + "RandomNodeInitializationReader", ] diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py new file mode 100644 index 0000000..68a4e8e --- /dev/null +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -0,0 +1,75 @@ +""" +Abboud, Ralph, et al. +"The surprising power of graph neural networks with random node initialization." +arXiv preprint arXiv:2010.01179 (2020). 
+ +Code Reference: https://github.com/ralphabb/GNN-RNI/blob/main/GNNHyb.py +""" + +import torch +from torch_geometric.data import Data as GeomData + +from .reader import GraphPropertyReader + + +class RandomNodeInitializationReader(GraphPropertyReader): + def __init__( + self, + num_node_properties: int, + num_bond_properties: int, + num_molecule_properties: int, + distribution: str, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.num_node_properties = num_node_properties + self.num_bond_properties = num_bond_properties + self.num_molecule_properties = num_molecule_properties + assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"] + self.distribution = distribution + + def name(self) -> str: + """ + Get the name identifier of the reader. + + Returns: + str: The name of the reader. + """ + return f"gni-{self.distribution}-node{self.num_node_properties}-bond{self.num_bond_properties}-mol{self.num_molecule_properties}" + + def _read_data(self, raw_data): + data: GeomData = super()._read_data(raw_data) + random_x = torch.empty(data.x.shape[0], self.num_node_properties) + random_edge_attr = torch.empty( + data.edge_index.shape[1], self.num_bond_properties + ) + random_molecule_properties = torch.empty(1, self.num_molecule_properties) + + if self.distribution == "normal": + torch.nn.init.normal_(random_x) + torch.nn.init.normal_(random_edge_attr) + torch.nn.init.normal_(random_molecule_properties) + elif self.distribution == "uniform": + torch.nn.init.uniform_(random_x, a=-1.0, b=1.0) + torch.nn.init.uniform_(random_edge_attr, a=-1.0, b=1.0) + torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0) + elif self.distribution == "xavier_normal": + torch.nn.init.xavier_normal_(random_x) + torch.nn.init.xavier_normal_(random_edge_attr) + torch.nn.init.xavier_normal_(random_molecule_properties) + elif self.distribution == "xavier_uniform": + torch.nn.init.xavier_uniform_(random_x) + 
torch.nn.init.xavier_uniform_(random_edge_attr) + torch.nn.init.xavier_uniform_(random_molecule_properties) + else: + raise ValueError("Unknown distribution type") + + data.x = random_x + data.edge_attr = random_edge_attr + data.molecule_attr = random_molecule_properties + return data + + def read_property(self, *args, **kwargs) -> Exception: + """This reader does not support reading specific properties.""" + raise NotImplementedError("This reader only performs random initialization.") diff --git a/configs/data/chebi50_static_gni.yml b/configs/data/chebi50_static_gni.yml new file mode 100644 index 0000000..1509802 --- /dev/null +++ b/configs/data/chebi50_static_gni.yml @@ -0,0 +1,7 @@ +class_path: chebai_graph.preprocessing.datasets.ChEBI50_StaticGNI +init_args: + reader_kwargs: + num_node_properties: 158 + num_bond_properties: 7 + num_molecule_properties: 200 + distribution: normal From de0c97f1970d9aee370b919412dfafccdbe4dbaa Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 28 Jul 2025 09:05:35 +0200 Subject: [PATCH 193/224] add as per node class to init --- chebai_graph/preprocessing/datasets/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py index 01c14fd..c821c5c 100644 --- a/chebai_graph/preprocessing/datasets/__init__.py +++ b/chebai_graph/preprocessing/datasets/__init__.py @@ -4,6 +4,7 @@ ChEBI50_NFGE_WGN_GraphProp, ChEBI50_StaticGNI, ChEBI50_WFGE_NGN_GraphProp, + ChEBI50_WFGE_WGN_AsPerNodeType, ChEBI50_WFGE_WGN_GraphProp, ChEBI50GraphData, ChEBI50GraphProperties, @@ -21,4 +22,5 @@ "ChEBI50_WFGE_NGN_GraphProp", "ChEBI50_WFGE_WGN_GraphProp", "ChEBI50_StaticGNI", + "ChEBI50_WFGE_WGN_AsPerNodeType", ] From 84c170b625c8e87bf8c8bd2804a70eb8955ad093 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 28 Jul 2025 10:42:53 +0200 Subject: [PATCH 194/224] add aug rdkit norm prop --- chebai_graph/preprocessing/datasets/chebi.py | 3 ++- 
chebai_graph/preprocessing/properties/__init__.py | 2 ++ .../properties/augmented_properties.py | 7 +++++++ chebai_graph/preprocessing/properties/base.py | 14 ++++++++++++++ configs/data/chebi50_aug_prop_as_per_node.yml | 2 +- 5 files changed, 26 insertions(+), 2 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index e278b59..a88492f 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -1,6 +1,7 @@ import os from abc import ABC from collections.abc import Callable +from pprint import pformat import pandas as pd import torch @@ -376,7 +377,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: ) rank_zero_info( f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" - f"Properties Categories {props_categories}\n" + f"Properties Categories:\n{pformat(props_categories)}" f"n_atom_node_properties: {n_atom_node_properties}, " f"n_fg_node_properties: {n_fg_node_properties}, " f"n_bond_properties: {n_bond_properties}, " diff --git a/chebai_graph/preprocessing/properties/__init__.py b/chebai_graph/preprocessing/properties/__init__.py index 9b6b393..a0de30d 100644 --- a/chebai_graph/preprocessing/properties/__init__.py +++ b/chebai_graph/preprocessing/properties/__init__.py @@ -43,6 +43,7 @@ AugBondAromaticity, AugBondType, AugBondInRing, + AugRDKit2DNormalized, ) # isort: on @@ -84,4 +85,5 @@ "AugBondAromaticity", "AugBondType", "AugBondInRing", + "AugRDKit2DNormalized", ] diff --git a/chebai_graph/preprocessing/properties/augmented_properties.py b/chebai_graph/preprocessing/properties/augmented_properties.py index b08d36b..f5f7b1d 100644 --- a/chebai_graph/preprocessing/properties/augmented_properties.py +++ b/chebai_graph/preprocessing/properties/augmented_properties.py @@ -14,6 +14,7 @@ AllNodeTypeProperty, AtomNodeTypeProperty, AugmentedBondProperty, + AugmentedMoleculeProperty, FGNodeTypeProperty, 
FrozenPropertyAlias, ) @@ -402,3 +403,9 @@ class AugBondInRing(AugBondValueDefaulter, pr.BondInRing): """ ... + + +# --------------------- Molecule Properties ------------------------------ + + +class AugRDKit2DNormalized(AugmentedMoleculeProperty, pr.RDKit2DNormalized): ... diff --git a/chebai_graph/preprocessing/properties/base.py b/chebai_graph/preprocessing/properties/base.py index e480b84..da5d9c2 100644 --- a/chebai_graph/preprocessing/properties/base.py +++ b/chebai_graph/preprocessing/properties/base.py @@ -439,3 +439,17 @@ def _get_bond_prop_value(bond: Chem.rdchem.Bond | dict, prop: str) -> str: return bond[prop] else: raise TypeError("Bond/Edge should be of type `Chem.rdchem.Bond` or `dict`.") + + +class AugmentedMoleculeProperty(MoleculeProperty, ABC): + def get_property_value(self, augmented_mol: dict) -> list: + """ + Get molecular property values from augmented molecule dict. + Args: + augmented_mol (dict): Augmented molecule dict. + Returns: + list: Property values of molecule. 
+ """ + mol: Chem.Mol = augmented_mol[AugmentedAtomProperty.MAIN_KEY]["atom_nodes"] + assert isinstance(mol, Chem.Mol), "Molecule should be instance of `Chem.Mol`" + return super().get_property_value(mol) diff --git a/configs/data/chebi50_aug_prop_as_per_node.yml b/configs/data/chebi50_aug_prop_as_per_node.yml index dcd77be..576cf75 100644 --- a/configs/data/chebi50_aug_prop_as_per_node.yml +++ b/configs/data/chebi50_aug_prop_as_per_node.yml @@ -16,7 +16,7 @@ init_args: - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG - chebai_graph.preprocessing.properties.IsFGAlkyl # Graph Node type properties - - chebai_graph.preprocessing.properties.RDKit2DNormalized + - chebai_graph.preprocessing.properties.AugRDKit2DNormalized # Bond properties - chebai_graph.preprocessing.properties.BondLevel - chebai_graph.preprocessing.properties.AugBondAromaticity From 4a1e82b6b115574c6eb7d64d60c47511f7e8c978 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 29 Jul 2025 00:13:09 +0200 Subject: [PATCH 195/224] fix mol props for as per node data cls --- chebai_graph/preprocessing/datasets/chebi.py | 13 +++++++------ chebai_graph/preprocessing/properties/properties.py | 6 ++++++ chebai_graph/preprocessing/property_encoder.py | 12 ++++++++++-- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index a88492f..7356600 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -376,14 +376,14 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: n_atom_node_properties, n_fg_node_properties, n_graph_node_properties ) rank_zero_info( - f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" - f"Properties Categories:\n{pformat(props_categories)}" + f"\nFinished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n\n" + f"Properties 
Categories:\n{pformat(props_categories)}\n\n" f"n_atom_node_properties: {n_atom_node_properties}, " f"n_fg_node_properties: {n_fg_node_properties}, " f"n_bond_properties: {n_bond_properties}, " - f"n_graph_node_properties: {n_graph_node_properties}\n" + f"n_graph_node_properties: {n_graph_node_properties}\n\n" f"Use following values for given parameters for model configuration: \n\t" - f"in_channels: {n_node_properties}, edge_dim: {n_bond_properties}, n_molecule_properties: 0" + f"in_channels: {n_node_properties}, edge_dim: {n_bond_properties}, n_molecule_properties: 0\n" ) for property in self.properties: @@ -449,7 +449,7 @@ def _merge_props_into_base( atom_offset, fg_offset, graph_offset = 0, 0, 0 for property in self.properties: - property_values = row[f"{property.name}"] + property_values = row[f"{property.name}"].to(dtype=torch.float32) if isinstance(property_values, torch.Tensor): if len(property_values.size()) == 0: property_values = property_values.unsqueeze(0) @@ -482,7 +482,7 @@ def _merge_props_into_base( elif isinstance(property, MoleculeProperty): x[is_graph_node, graph_offset : graph_offset + enc_len] = ( - property_values[is_graph_node] + property_values ) graph_offset += enc_len @@ -505,6 +505,7 @@ def _merge_props_into_base( x=x, edge_index=geom_data.edge_index, edge_attr=edge_attr, + molecule_attr=torch.empty((1, 0)), # empty as not used for this class ) diff --git a/chebai_graph/preprocessing/properties/properties.py b/chebai_graph/preprocessing/properties/properties.py index 19e8a4e..b76f244 100644 --- a/chebai_graph/preprocessing/properties/properties.py +++ b/chebai_graph/preprocessing/properties/properties.py @@ -275,6 +275,12 @@ class RDKit2DNormalized(MoleculeProperty): def __init__(self, encoder: PropertyEncoder | None = None) -> None: super().__init__(encoder or AsIsEncoder(self)) self.generator_normalized = rdNormalizedDescriptors.RDKit2DNormalized() + # Create a dummy molecule (e.g., methane) to extract the length of descriptor vector + 
dummy_mol = Chem.MolFromSmiles("C") + descr_values = self.generator_normalized.processMol( + dummy_mol, Chem.MolToSmiles(dummy_mol) + ) + self.encoder.set_encoding_length(len(descr_values) - 1) def get_property_value(self, mol: Chem.rdchem.Mol) -> list[np.ndarray]: """ diff --git a/chebai_graph/preprocessing/property_encoder.py b/chebai_graph/preprocessing/property_encoder.py index aff1bde..1487163 100644 --- a/chebai_graph/preprocessing/property_encoder.py +++ b/chebai_graph/preprocessing/property_encoder.py @@ -243,8 +243,16 @@ def encode(self, token: float | int | None) -> torch.Tensor: Tensor of shape (1,) containing the input value or zero. """ if token is None: - return torch.tensor([0]) - return torch.tensor([token]) + return torch.zeros(1, self.get_encoding_length()) + assert ( + len(token) == self.get_encoding_length() + ), "Length of token should be equal to encoding length" + # return torch.tensor([token]) # token is an ndarray, no need to create list of ndarray due to below warning + # UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. + # Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. + # (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\utils\tensor_new.cpp:257.) 
+ # ----- fix: for above warning + return torch.tensor(token).unsqueeze(0) # shape: (1, len(token)) class BoolEncoder(PropertyEncoder): From b2055602626b0d52c62bc01a3e462b8b3e9f26a7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 3 Aug 2025 14:30:57 +0200 Subject: [PATCH 196/224] add few more ablation study aug readers --- .../bin/BondLevel/indices_one_hot.txt | 2 +- .../preprocessing/properties/constants.py | 2 +- chebai_graph/preprocessing/reader/__init__.py | 8 ++ .../preprocessing/reader/augmented_reader.py | 135 +++++++++++++++++- .../utils/visualize_augmented_molecule.py | 20 ++- 5 files changed, 153 insertions(+), 14 deletions(-) diff --git a/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt b/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt index c5f7ed0..b389485 100644 --- a/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt +++ b/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt @@ -1,4 +1,4 @@ atom_fg_lvl -fg_graphNode_lvl +to_graphNode_lvl within_atoms_lvl within_fg_lvl diff --git a/chebai_graph/preprocessing/properties/constants.py b/chebai_graph/preprocessing/properties/constants.py index e73da46..e4cd2b9 100644 --- a/chebai_graph/preprocessing/properties/constants.py +++ b/chebai_graph/preprocessing/properties/constants.py @@ -8,6 +8,6 @@ WITHIN_ATOMS_EDGE = "within_atoms_lvl" WITHIN_FG_EDGE = "within_fg_lvl" ATOM_FG_EDGE = "atom_fg_lvl" -TO_GRAPHNODE_EDGE = "fg_graphNode_lvl" +TO_GRAPHNODE_EDGE = "to_graphNode_lvl" EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, TO_GRAPHNODE_EDGE} NUM_EDGES = "num_undirected_edges" diff --git a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py index 9e58e32..ee3fce0 100644 --- a/chebai_graph/preprocessing/reader/__init__.py +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -4,6 +4,10 @@ AtomFGReader_WithFGEdges_WithGraphNode, AtomReader_WithGraphNodeOnly, AtomsFGReader_NoFGEdges_NoGraphNode, + 
GN_WithAllNodes_FG_WithAtoms_FGE, + GN_WithAllNodes_FG_WithAtoms_NoFGE, + GN_WithAtoms_FG_WithAtoms_FGE, + GN_WithAtoms_FG_WithAtoms_NoFGE, ) from .reader import GraphPropertyReader, GraphReader from .static_gni import RandomNodeInitializationReader @@ -17,4 +21,8 @@ "AtomFGReader_WithFGEdges_NoGraphNode", "AtomFGReader_WithFGEdges_WithGraphNode", "RandomNodeInitializationReader", + "GN_WithAtoms_FG_WithAtoms_FGE", + "GN_WithAtoms_FG_WithAtoms_NoFGE", + "GN_WithAllNodes_FG_WithAtoms_FGE", + "GN_WithAllNodes_FG_WithAtoms_NoFGE", ] diff --git a/chebai_graph/preprocessing/reader/augmented_reader.py b/chebai_graph/preprocessing/reader/augmented_reader.py index 4d0c22e..b6b2d90 100644 --- a/chebai_graph/preprocessing/reader/augmented_reader.py +++ b/chebai_graph/preprocessing/reader/augmented_reader.py @@ -307,13 +307,17 @@ def read_property(self, smiles: str, property: MolecularProperty) -> list | None class AtomsFGReader_NoFGEdges_NoGraphNode(_AugmentorReader): - """Adds FG nodes without intra-functional group edges and without introducing a graph-level node.""" + """ + Adds FG nodes (connected to their respective atom nodes) without + intra-functional group edges, and without introducing a graph-level node. + """ def _augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: """ - Constructs the full augmented graph structure from a molecule. + Constructs the full augmented graph structure from a molecule by adding + fg nodes to their respective atom nodes. Args: mol (Chem.Mol): RDKit molecule object. @@ -545,7 +549,10 @@ def _set_fg_prop( class AtomFGReader_WithFGEdges_NoGraphNode(AtomsFGReader_NoFGEdges_NoGraphNode): - """Adds FG nodes with intra-functional group edges and without introducing a graph-level node.""" + """ + Adds FG nodes (connected to their respective atom nodes) with intra-functional group + edges, and without introducing a graph-level node. 
+ """ def _augment_graph_structure( self, mol: Chem.Mol @@ -754,13 +761,16 @@ def _construct_nodes_to_graph_node_structure( class AtomFGReader_WithFGEdges_WithGraphNode( AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode ): - """Adds FG nodes with intra-functional group edges and a graph-level node.""" + """ + Adds FG nodes (connected to their respective atom nodes) with intra-functional group + edges, and adds a graph-level node connected to all FG nodes. + """ def _augment_graph_structure( self, mol: Chem.Mol ) -> tuple[torch.Tensor, dict, dict]: """ - Augments the graph with FG edges and a global graph-level node. + Augments the graph with a global graph-level node. Args: mol (Chem.Mol): RDKit molecule object. @@ -778,7 +788,10 @@ def _augment_graph_structure( class AtomFGReader_NoFGEdges_WithGraphNode( AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode ): - """Adds FG nodes without functional group edges and a graph-level node.""" + """ + Adds FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all FG nodes. + """ def _augment_graph_structure( self, mol: Chem.Mol @@ -818,3 +831,113 @@ def _augment_graph_structure( molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} return self._add_graph_node_and_edges_to_nodes(augmented_struct, atom_ids) + + +class GN_WithAtoms_FG_WithAtoms_NoFGE( + AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all atom nodes. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. 
+ """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes(augmented_struct, atom_ids) + + +class GN_WithAtoms_FG_WithAtoms_FGE( + AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all atom nodes. + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes(augmented_struct, atom_ids) + + +class GN_WithAllNodes_FG_WithAtoms_FGE( + AtomFGReader_WithFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. 
+ """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, atom_ids | fg_to_atoms_map.keys() + ) + + +class GN_WithAllNodes_FG_WithAtoms_NoFGE( + AtomsFGReader_NoFGEdges_NoGraphNode, _AddGraphNode +): + """ + Adds FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + def _augment_graph_structure( + self, mol: Chem.Mol + ) -> tuple[torch.Tensor, dict, dict]: + """ + Augments the graph with a global graph-level node. + + Args: + mol (Chem.Mol): RDKit molecule object. + + Returns: + tuple[torch.Tensor, dict, dict]: Updated graph structure. + """ + augmented_struct = super()._augment_graph_structure(mol) + molecule: Chem.Mol = augmented_struct["node_info"]["atom_nodes"] + fg_to_atoms_map = augmented_struct["graph_meta_info"]["fg_to_atoms_map"] + atom_ids = {atom.GetIdx() for atom in molecule.GetAtoms()} + return self._add_graph_node_and_edges_to_nodes( + augmented_struct, atom_ids | fg_to_atoms_map.keys() + ) diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py index 897422d..fcc406b 100644 --- a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py +++ b/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py @@ -17,6 +17,10 @@ AtomFGReader_WithFGEdges_WithGraphNode, AtomReader_WithGraphNodeOnly, AtomsFGReader_NoFGEdges_NoGraphNode, + GN_WithAllNodes_FG_WithAtoms_FGE, + GN_WithAllNodes_FG_WithAtoms_NoFGE, + GN_WithAtoms_FG_WithAtoms_FGE, + GN_WithAtoms_FG_WithAtoms_NoFGE, ) matplotlib.use("TkAgg") @@ -51,6 +55,10 @@ "w_fge_n_gn": AtomFGReader_WithFGEdges_NoGraphNode, 
"n_fge_n_gn": AtomsFGReader_NoFGEdges_NoGraphNode, "atom_w_gn": AtomReader_WithGraphNodeOnly, + "gnwa_fgwa_nfge": GN_WithAtoms_FG_WithAtoms_NoFGE, + "gnwa_fgwa_wfge": GN_WithAtoms_FG_WithAtoms_FGE, + "gn_wall_fgwa_nfge": GN_WithAllNodes_FG_WithAtoms_NoFGE, + "gn_wall_fgwa_wfge": GN_WithAllNodes_FG_WithAtoms_FGE, } @@ -286,10 +294,6 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: neighbor_type = { G.nodes[nbr].get("node_type") for nbr in G.neighbors(graph_node) } - assert neighbor_type < { - "fg", - "atom", - }, f"Graph node {graph_node} must only connect to one type of node: {neighbor_type}" if "fg" in neighbor_type: graph_pos_arr = np.array( @@ -478,11 +482,15 @@ def plot( - h: Hierarchical 2D-graph with separate plane for each node type - 3d: Hierarchical 3D-graph reader (str): Reader type for graph augmentation. Options: - - 'n_fge_w_gn': FG nodes without FG edges, with a graph node. - - 'w_fge_w_gn': FG nodes with FG edges, with a graph node. + - 'n_fge_w_gn': FG nodes without FG edges, with a graph node connected to fg nodes. + - 'w_fge_w_gn': FG nodes with FG edges, with a graph node connected to fg nodes. - 'w_fge_n_gn': FG nodes with FG edges, no graph node. - 'n_fge_n_gn': FG nodes without FG edges, no graph node. - 'atom_w_gn': Atom nodes only, connected to a graph node. + - 'gnwa_fgwa_nfge': Graph node connected to atoms and FG nodes connected to atoms, no FG edges. + - 'gnwa_fgwa_wfge': Graph node connected to atoms and FG nodes connected to atoms, with FG edges. + - 'gn_wall_fgwa_nfge': Graph node connected to all atoms and FG nodes, no FG edges. + - 'gn_wall_fgwa_wfge': Graph node connected to all atoms and FG nodes, with FG edges. 
""" fg_reader = READER[reader]() mol = fg_reader._smiles_to_mol(smiles) From 36102f862030d9e99d74d0ed91288f38544beb26 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 3 Aug 2025 15:22:08 +0200 Subject: [PATCH 197/224] move vis script to results --- chebai_graph/preprocessing/utils/__init__.py | 0 .../utils => results}/visualize_augmented_molecule.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 chebai_graph/preprocessing/utils/__init__.py rename {chebai_graph/preprocessing/utils => results}/visualize_augmented_molecule.py (100%) diff --git a/chebai_graph/preprocessing/utils/__init__.py b/chebai_graph/preprocessing/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chebai_graph/preprocessing/utils/visualize_augmented_molecule.py b/results/visualize_augmented_molecule.py similarity index 100% rename from chebai_graph/preprocessing/utils/visualize_augmented_molecule.py rename to results/visualize_augmented_molecule.py From 13384ac2703a69cce12c786521ddfff51568cd51 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 8 Aug 2025 00:00:34 +0200 Subject: [PATCH 198/224] new chebi data class for new readers --- .../preprocessing/datasets/__init__.py | 8 ++++ chebai_graph/preprocessing/datasets/chebi.py | 48 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/chebai_graph/preprocessing/datasets/__init__.py b/chebai_graph/preprocessing/datasets/__init__.py index c821c5c..d13b1c3 100644 --- a/chebai_graph/preprocessing/datasets/__init__.py +++ b/chebai_graph/preprocessing/datasets/__init__.py @@ -1,5 +1,9 @@ from .chebi import ( ChEBI50_Atom_WGNOnly_GraphProp, + ChEBI50_GN_WithAllNodes_FG_WithAtoms_FGE, + ChEBI50_GN_WithAllNodes_FG_WithAtoms_NoFGE, + ChEBI50_GN_WithAtoms_FG_WithAtoms_FGE, + ChEBI50_GN_WithAtoms_FG_WithAtoms_NoFGE, ChEBI50_NFGE_NGN_GraphProp, ChEBI50_NFGE_WGN_GraphProp, ChEBI50_StaticGNI, @@ -23,4 +27,8 @@ "ChEBI50_WFGE_WGN_GraphProp", "ChEBI50_StaticGNI", "ChEBI50_WFGE_WGN_AsPerNodeType", 
+ "ChEBI50_GN_WithAllNodes_FG_WithAtoms_FGE", + "ChEBI50_GN_WithAllNodes_FG_WithAtoms_NoFGE", + "ChEBI50_GN_WithAtoms_FG_WithAtoms_FGE", + "ChEBI50_GN_WithAtoms_FG_WithAtoms_NoFGE", ] diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 7356600..62c5cf5 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -30,6 +30,10 @@ AtomFGReader_WithFGEdges_WithGraphNode, AtomReader_WithGraphNodeOnly, AtomsFGReader_NoFGEdges_NoGraphNode, + GN_WithAllNodes_FG_WithAtoms_FGE, + GN_WithAllNodes_FG_WithAtoms_NoFGE, + GN_WithAtoms_FG_WithAtoms_FGE, + GN_WithAtoms_FG_WithAtoms_NoFGE, GraphPropertyReader, GraphReader, RandomNodeInitializationReader, @@ -595,6 +599,50 @@ class ChEBI50_WFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): READER = AtomFGReader_WithFGEdges_WithGraphNode +class ChEBI50_GN_WithAllNodes_FG_WithAtoms_FGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + READER = GN_WithAllNodes_FG_WithAtoms_FGE + + +class ChEBI50_GN_WithAllNodes_FG_WithAtoms_NoFGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all nodes (fg + atoms). + """ + + READER = GN_WithAllNodes_FG_WithAtoms_NoFGE + + +class ChEBI50_GN_WithAtoms_FG_WithAtoms_FGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) with functional group + edges, and adds a graph-level node connected to all atom nodes. 
+ """ + + READER = GN_WithAtoms_FG_WithAtoms_FGE + + +class ChEBI50_GN_WithAtoms_FG_WithAtoms_NoFGE( + AugGraphPropMixIn_WithGraphNode, ChEBIOver50 +): + """ + ChEBIOver50 with FG nodes (connected to their respective atom nodes) without functional group + edges, and adds a graph-level node connected to all atom nodes. + """ + + READER = GN_WithAtoms_FG_WithAtoms_NoFGE + + class ChEBI50_NFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): """ChEBIOver50 with FG nodes but without FG edges, with graph node.""" From d0ea81d9f67dcf5fa9649c71e47133aff6749f15 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 11 Aug 2025 18:48:53 +0200 Subject: [PATCH 199/224] as node mask for as per node data cls --- chebai_graph/preprocessing/datasets/chebi.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 62c5cf5..f3b4ebd 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -510,6 +510,9 @@ def _merge_props_into_base( edge_index=geom_data.edge_index, edge_attr=edge_attr, molecule_attr=torch.empty((1, 0)), # empty as not used for this class + is_atom_node=is_atom_node, + is_fg_node=is_fg_node, + is_graph_node=is_graph_node, ) From fb275210872a5142405af63885a1422a9e836692 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 19 Aug 2025 20:05:50 +0200 Subject: [PATCH 200/224] fix hyperparams for gat and resgated --- configs/model/gat.yml | 4 ++-- configs/model/resgated.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/model/gat.yml b/configs/model/gat.yml index a8cfbe0..d72b4a7 100644 --- a/configs/model/gat.yml +++ b/configs/model/gat.yml @@ -6,10 +6,10 @@ init_args: in_channels: 158 # number of node/atom properties hidden_channels: 256 out_channels: 512 - num_layers: 5 + num_layers: 4 edge_dim: 7 # number of bond properties heads: 8 # the number of heads should be divisible by 
output channels (hidden channels if output channel not given) v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv dropout: 0 n_molecule_properties: 0 - n_linear_layers: 2 + n_linear_layers: 1 diff --git a/configs/model/resgated.yml b/configs/model/resgated.yml index 244f471..ccc6615 100644 --- a/configs/model/resgated.yml +++ b/configs/model/resgated.yml @@ -6,8 +6,8 @@ init_args: in_channels: 158 # number of node/atom properties hidden_channels: 256 out_channels: 512 - num_layers: 5 + num_layers: 4 edge_dim: 7 # number of bond properties dropout: 0 n_molecule_properties: 0 - n_linear_layers: 2 + n_linear_layers: 1 From e09e23cbb97b2b339ca54d984b9a26d844d449e2 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 19 Aug 2025 20:15:26 +0200 Subject: [PATCH 201/224] gat aug node pool class --- chebai_graph/models/__init__.py | 2 ++ chebai_graph/models/augmented.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index cf1796b..e2af0c3 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -1,4 +1,5 @@ from .augmented import ( + GATAugNodePoolGraphPred, ResGatedAugNodePoolGraphPred, ResGatedAugOnlyPoolGraphPred, ResGatedFGNodeNoGraphNodeGraphPred, @@ -24,4 +25,5 @@ "ResGatedAugOnlyPoolGraphPred", "ResGatedGraphNodeOnlyPoolGraphPred", "ResGatedFGOnlyPoolGraphPred", + "GATAugNodePoolGraphPred", ] diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py index b1cae5b..22dff80 100644 --- a/chebai_graph/models/augmented.py +++ b/chebai_graph/models/augmented.py @@ -9,6 +9,7 @@ GraphNodeOnlyPoolingNet, GraphNodePoolingNet, ) +from .gat import GATGraphPred from .resgated import ResGatedGraphPred @@ -22,6 +23,16 @@ class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred): ... 
+class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred): + """ + Combines: + - AugmentedNodePoolingNet: Pools atom and augmented node embeddings with molecule attributes. + - GATGraphPred: Graph attention network for final graph prediction. + """ + + ... + + class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, ResGatedGraphPred): """ Combines: From 29c3ea3fa4420c21b22d9668f955afa56b1a3885 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 26 Aug 2025 16:32:29 +0200 Subject: [PATCH 202/224] add fg tokens file --- .../AtomFunctionalGroup/indices_one_hot.txt | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt diff --git a/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt b/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt new file mode 100644 index 0000000..adc6d60 --- /dev/null +++ b/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt @@ -0,0 +1,158 @@ +NO_FG +graph_fg +alkyl +4_ammonium_ion +fluoro +chloro +ether +RING_6 +nitrile +primary_amine +secondary_amine +tertiary_carbon +RING_5 +alkene +RING_7 +tertiary_amine +ketone +hydroxyl +sulfide +sulfonyl +amide +ester +carboxyl +quaternary_carbon +secondary_ketimine +alkyne +alkene_carbon +nitro +trifluoromethyl +RING_4 +secondary_aldimine +primary_ketimine +acetal +carbonate_ester +azo +carbamate +bromo +RING_3 +phosphono +RING_40 +imide +RING_19 +nitroso +RING_8 +sulfhydryl +aldehyde +ketoxime +borono +difluoromethyl +thiolester +aldoxime +carboxylate +thiocyanate +isothiocyanate +phosphate +phosphodiester +iodo +sulfonic_acid +carbodithio +sulfinyl +azide +thioketone +RING_23 +trimethylsilyl +RING_21 +sulfonate_ester +RING_16 +RING_15 +RING_9 +RING_18 +RING_20 +RING_32 +RING_14 +RING_22 +nitrate +primary_aldimine +disulfide +RING_10 +hemiacetal +dichloromethyl +trichloromethyl +haloformyl +hydroperoxy +hemiketal +isonitrile +RING_24 
+RING_11 +RING_12 +nitrosooxy +RING_33 +phosphino +sulfino +RING_17 +cyanate +RING_13 +carbothioic_O-acid +RING_36 +RING_72 +hydrazone +RING_31 +RING_25 +RING_28 +RING_35 +RING_34 +RING_42 +RING_38 +RING_29 +RING_26 +RING_51 +RING_75 +RING_27 +RING_45 +RING_60 +RING_30 +RING_43 +RING_41 +RING_56 +RING_81 +RING_39 +RING_37 +RING_47 +RING_49 +RING_44 +phosphoryl +RING_54 +RING_53 +RING_63 +carboxylic_anhydride +carbothioic_S-acid +silyl_ether +borinate +carbodithioic_acid +peroxy +tribromomethyl +chlorobromomethyl +amidine +dibromomethyl +isocyanate +bromodichloromethyl +ketal +boronate +borino +chloroiodomethyl +bromoiodomethyl +thionoester +fluorochloromethyl +RING_68 +bromodifluoromethyl +RING_58 +difluorochloromethyl +thial +RING_52 +RING_48 +RING_62 +RING_164 +RING_71 +RING_46 +orthoester From 4a1857202792b48fdbd6d98b0e9746a026d57867 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 29 Aug 2025 12:59:18 +0200 Subject: [PATCH 203/224] classwise f1 score analysis tutorial --- tutorials/classwise_stats.ipynb | 767 ++++++++++++++++++++++++++++++++ 1 file changed, 767 insertions(+) create mode 100644 tutorials/classwise_stats.ipynb diff --git a/tutorials/classwise_stats.ipynb b/tutorials/classwise_stats.ipynb new file mode 100644 index 0000000..072a128 --- /dev/null +++ b/tutorials/classwise_stats.ipynb @@ -0,0 +1,767 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 21, + "id": "07f2eb38", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "eb64bfee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelres_baseline_propsres_wfgn_wfge_wgn_wprops_apoolgatv2_wfgn_wfge_wgn_wprops_apoolgatv2 - baselineres - baseline
01461800.00.00000.00000.00000.0000
1600040.00.11760.00000.00000.1176
2167330.00.23530.00000.00000.2353
3222990.00.28570.00000.00000.2857
4178150.00.12500.11110.11110.1250
\n", + "
" + ], + "text/plain": [ + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "0 146180 0.0 0.0000 \n", + "1 60004 0.0 0.1176 \n", + "2 16733 0.0 0.2353 \n", + "3 22299 0.0 0.2857 \n", + "4 17815 0.0 0.1250 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \n", + "0 0.0000 0.0000 0.0000 \n", + "1 0.0000 0.0000 0.1176 \n", + "2 0.0000 0.0000 0.2353 \n", + "3 0.0000 0.0000 0.2857 \n", + "4 0.1111 0.1111 0.1250 " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_excel(\"G:\\github-aditya0by0\\python-chebai-graph\\classwise_f1_score.xlsx\")\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "14c10198", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1528, 6)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "32062eb6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['label', 'res_baseline_props', 'res_wfgn_wfge_wgn_wprops_apool',\n", + " 'gatv2_wfgn_wfge_wgn_wprops_apool', 'gatv2 - baseline ',\n", + " 'res - baseline'],\n", + " dtype='object')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5cd9265a", + "metadata": {}, + "outputs": [], + "source": [ + "BASELINE = \"res_baseline_props\"\n", + "RES = \"res_wfgn_wfge_wgn_wprops_apool\"\n", + "GATv2 = \"gatv2_wfgn_wfge_wgn_wprops_apool\"\n", + "GATv2_BASELINE_DIFF = \"gatv2 - baseline \"\n", + "RES_BASELINE_DIFF = \"res - baseline\"" + ] + }, + { + "cell_type": "markdown", + "id": "f679dfee", + "metadata": {}, + "source": [ + "### Summary statistics\n", + "\n", + "Look at the overall performance:" + ] + }, + { + "cell_type": "code", 
+ "execution_count": 16, + "id": "b38fa321", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean improvement over baseline:\n", + " res - baseline 0.054853\n", + "gatv2 - baseline 0.068136\n", + "dtype: float64\n", + "\n", + "Median improvement over baseline:\n", + " res - baseline 0.03330\n", + "gatv2 - baseline 0.04535\n", + "dtype: float64\n", + "\n", + "Variation in improvement:\n", + " res - baseline 0.144831\n", + "gatv2 - baseline 0.139299\n", + "dtype: float64\n" + ] + } + ], + "source": [ + "# mean improvement\n", + "mean_diff = df[[RES_BASELINE_DIFF, GATv2_BASELINE_DIFF]].mean()\n", + "print(\"Mean improvement over baseline:\\n\", mean_diff)\n", + "print()\n", + "\n", + "# median improvement\n", + "median_diff = df[[RES_BASELINE_DIFF, GATv2_BASELINE_DIFF]].median()\n", + "print(\"Median improvement over baseline:\\n\", median_diff)\n", + "print()\n", + "\n", + "\n", + "# standard deviation of improvements\n", + "std_diff = df[[RES_BASELINE_DIFF, GATv2_BASELINE_DIFF]].std()\n", + "print(\"Variation in improvement:\\n\", std_diff)" + ] + }, + { + "cell_type": "markdown", + "id": "6359ca34", + "metadata": {}, + "source": [ + "### Count of labels improved / worsened" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "d046db01", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Labels improved in ResGated Final model over baseline: 1010\n", + "Labels improved in GATv2 Final model over baseline: 1091\n" + ] + } + ], + "source": [ + "res_improvement = (df[RES_BASELINE_DIFF] > 0).sum()\n", + "gat_v2_improvement = (df[GATv2_BASELINE_DIFF] > 0).sum()\n", + "\n", + "print(f\"Labels improved in ResGated Final model over baseline: {res_improvement}\")\n", + "print(f\"Labels improved in GATv2 Final model over baseline: {gat_v2_improvement}\")" + ] + }, + { + "cell_type": "markdown", + "id": "91dd094e", + "metadata": {}, + "source": [ + "### Ranking 
per label" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "3c3972a1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "best_model\n", + "gatv2_wfgn_wfge_wgn_wprops_apool 678\n", + "res_wfgn_wfge_wgn_wprops_apool 524\n", + "res_baseline_props 326\n", + "Name: count, dtype: int64\n" + ] + } + ], + "source": [ + "df[\"best_model\"] = df[[BASELINE, RES, GATv2]].idxmax(axis=1)\n", + "print(df[\"best_model\"].value_counts())" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "02ecb78e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Top 10 classes with highest improvement for res - baseline:\n", + " 25 0.8889\n", + "24 0.8889\n", + "37 0.7692\n", + "34 0.7143\n", + "35 0.7143\n", + "29 0.7143\n", + "77 0.6889\n", + "41 0.6487\n", + "36 0.6400\n", + "23 0.6364\n", + "Name: res - baseline, dtype: float64\n", + "\n", + "Top 10 classes with highest improvement for gatv2 - baseline :\n", + " 25 0.8889\n", + "24 0.8889\n", + "37 0.7692\n", + "34 0.7143\n", + "35 0.7143\n", + "29 0.7143\n", + "77 0.6889\n", + "41 0.6487\n", + "36 0.6400\n", + "23 0.6364\n", + "Name: res - baseline, dtype: float64\n" + ] + } + ], + "source": [ + "# Classes with highest difference\n", + "top_diff_res = df[RES_BASELINE_DIFF].sort_values(ascending=False).head(10)\n", + "print(\n", + " f\"\\nTop 10 classes with highest improvement for {RES_BASELINE_DIFF}:\\n\",\n", + " top_diff_res,\n", + ")\n", + "\n", + "top_diff_gat = df[GATv2_BASELINE_DIFF].sort_values(ascending=False).head(10)\n", + "print(\n", + " f\"\\nTop 10 classes with highest improvement for {GATv2_BASELINE_DIFF}:\\n\",\n", + " top_diff_res,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "524a01d8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
labelres_baseline_propsres_wfgn_wfge_wgn_wprops_apoolgatv2_wfgn_wfge_wgn_wprops_apoolgatv2 - baselineres - baselinebest_model
01461800.00.00.00000.00000.0res_baseline_props
9257110.00.00.25000.25000.0gatv2_wfgn_wfge_wgn_wprops_apool
10390930.00.00.25000.25000.0gatv2_wfgn_wfge_wgn_wprops_apool
1841940.00.00.46150.46150.0gatv2_wfgn_wfge_wgn_wprops_apool
19176080.00.00.46150.46150.0gatv2_wfgn_wfge_wgn_wprops_apool
........................
1523766291.01.01.00000.00000.0res_baseline_props
1524787761.01.01.00000.00000.0res_baseline_props
1525792041.01.01.00000.00000.0res_baseline_props
15261327421.01.01.00000.00000.0res_baseline_props
15271366371.01.01.00000.00000.0res_baseline_props
\n", + "

87 rows × 7 columns

\n", + "
" + ], + "text/plain": [ + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "0 146180 0.0 0.0 \n", + "9 25711 0.0 0.0 \n", + "10 39093 0.0 0.0 \n", + "18 4194 0.0 0.0 \n", + "19 17608 0.0 0.0 \n", + "... ... ... ... \n", + "1523 76629 1.0 1.0 \n", + "1524 78776 1.0 1.0 \n", + "1525 79204 1.0 1.0 \n", + "1526 132742 1.0 1.0 \n", + "1527 136637 1.0 1.0 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "0 0.0000 0.0000 0.0 \n", + "9 0.2500 0.2500 0.0 \n", + "10 0.2500 0.2500 0.0 \n", + "18 0.4615 0.4615 0.0 \n", + "19 0.4615 0.4615 0.0 \n", + "... ... ... ... \n", + "1523 1.0000 0.0000 0.0 \n", + "1524 1.0000 0.0000 0.0 \n", + "1525 1.0000 0.0000 0.0 \n", + "1526 1.0000 0.0000 0.0 \n", + "1527 1.0000 0.0000 0.0 \n", + "\n", + " best_model \n", + "0 res_baseline_props \n", + "9 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "10 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "18 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "19 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "... ... 
\n", + "1523 res_baseline_props \n", + "1524 res_baseline_props \n", + "1525 res_baseline_props \n", + "1526 res_baseline_props \n", + "1527 res_baseline_props \n", + "\n", + "[87 rows x 7 columns]" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Classes with minimal or no difference\n", + "minimal_diff_threshold = 0.01 # you can adjust\n", + "minimal_diff = df[df[RES_BASELINE_DIFF].abs() <= minimal_diff_threshold]\n", + "minimal_diff" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "5f15e872", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Classes where baseline performs better than our model:\n", + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "45 59999 0.1176 0.0541 \n", + "47 59266 0.1176 0.1143 \n", + "50 25095 0.1333 0.0000 \n", + "70 48927 0.2000 0.0000 \n", + "72 83943 0.2000 0.0000 \n", + "... ... ... ... \n", + "1508 24846 1.0000 0.9333 \n", + "1509 26187 1.0000 0.9333 \n", + "1510 78231 1.0000 0.9333 \n", + "1511 26377 1.0000 0.9565 \n", + "1512 38834 1.0000 0.9677 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "45 0.1250 0.0074 -0.0635 \n", + "47 0.5385 0.4209 -0.0033 \n", + "50 0.2667 0.1334 -0.1333 \n", + "70 0.0000 -0.2000 -0.2000 \n", + "72 0.2000 0.0000 -0.2000 \n", + "... ... ... ... \n", + "1508 1.0000 0.0000 -0.0667 \n", + "1509 1.0000 0.0000 -0.0667 \n", + "1510 1.0000 0.0000 -0.0667 \n", + "1511 1.0000 0.0000 -0.0435 \n", + "1512 1.0000 0.0000 -0.0323 \n", + "\n", + " best_model \n", + "45 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "47 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "50 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "70 res_baseline_props \n", + "72 res_baseline_props \n", + "... ... 
\n", + "1508 res_baseline_props \n", + "1509 res_baseline_props \n", + "1510 res_baseline_props \n", + "1511 res_baseline_props \n", + "1512 res_baseline_props \n", + "\n", + "[431 rows x 7 columns]\n" + ] + } + ], + "source": [ + "# Classes where baseline performs better\n", + "baseline_better = df[df[RES_BASELINE_DIFF] < 0]\n", + "print(\"\\nClasses where baseline performs better than our model:\\n\", baseline_better)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "0c4292bf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Classes where both baseline and res_wfgn_wfge_wgn_wprops_apool have F1 < 0.5:\n", + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "0 146180 0.0000 0.0000 \n", + "1 60004 0.0000 0.1176 \n", + "2 16733 0.0000 0.2353 \n", + "3 22299 0.0000 0.2857 \n", + "4 17815 0.0000 0.1250 \n", + ".. ... ... ... \n", + "265 32877 0.4615 0.4706 \n", + "271 139358 0.4667 0.4000 \n", + "283 35346 0.4762 0.4706 \n", + "284 35507 0.4762 0.4737 \n", + "295 24586 0.4878 0.4762 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "0 0.0000 0.0000 0.0000 \n", + "1 0.0000 0.0000 0.1176 \n", + "2 0.0000 0.0000 0.2353 \n", + "3 0.0000 0.0000 0.2857 \n", + "4 0.1111 0.1111 0.1250 \n", + ".. ... ... ... \n", + "265 0.5333 0.0718 0.0091 \n", + "271 0.5517 0.0850 -0.0667 \n", + "283 0.3529 -0.1233 -0.0056 \n", + "284 0.3871 -0.0891 -0.0025 \n", + "295 0.6667 0.1789 -0.0116 \n", + "\n", + " best_model \n", + "0 res_baseline_props \n", + "1 res_wfgn_wfge_wgn_wprops_apool \n", + "2 res_wfgn_wfge_wgn_wprops_apool \n", + "3 res_wfgn_wfge_wgn_wprops_apool \n", + "4 res_wfgn_wfge_wgn_wprops_apool \n", + ".. ... 
\n", + "265 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "271 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "283 res_baseline_props \n", + "284 res_baseline_props \n", + "295 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "\n", + "[148 rows x 7 columns]\n" + ] + } + ], + "source": [ + "# Classes where both baseline and our model perform worse\n", + "# Assuming \"worse\" means below a threshold, e.g., f1 < 0.5\n", + "threshold = 0.5\n", + "both_worse = df[(df[BASELINE] < threshold) & (df[RES] < threshold)]\n", + "print(f\"\\nClasses where both baseline and {RES} have F1 < {threshold}:\\n\", both_worse)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "8948b7bb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Classes where baseline performs better than our model:\n", + " label res_baseline_props res_wfgn_wfge_wgn_wprops_apool \\\n", + "45 59999 0.1176 0.0541 \n", + "47 59266 0.1176 0.1143 \n", + "50 25095 0.1333 0.0000 \n", + "70 48927 0.2000 0.0000 \n", + "72 83943 0.2000 0.0000 \n", + "... ... ... ... \n", + "1508 24846 1.0000 0.9333 \n", + "1509 26187 1.0000 0.9333 \n", + "1510 78231 1.0000 0.9333 \n", + "1511 26377 1.0000 0.9565 \n", + "1512 38834 1.0000 0.9677 \n", + "\n", + " gatv2_wfgn_wfge_wgn_wprops_apool gatv2 - baseline res - baseline \\\n", + "45 0.1250 0.0074 -0.0635 \n", + "47 0.5385 0.4209 -0.0033 \n", + "50 0.2667 0.1334 -0.1333 \n", + "70 0.0000 -0.2000 -0.2000 \n", + "72 0.2000 0.0000 -0.2000 \n", + "... ... ... ... \n", + "1508 1.0000 0.0000 -0.0667 \n", + "1509 1.0000 0.0000 -0.0667 \n", + "1510 1.0000 0.0000 -0.0667 \n", + "1511 1.0000 0.0000 -0.0435 \n", + "1512 1.0000 0.0000 -0.0323 \n", + "\n", + " best_model \n", + "45 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "47 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "50 gatv2_wfgn_wfge_wgn_wprops_apool \n", + "70 res_baseline_props \n", + "72 res_baseline_props \n", + "... ... 
\n", + "1508 res_baseline_props \n", + "1509 res_baseline_props \n", + "1510 res_baseline_props \n", + "1511 res_baseline_props \n", + "1512 res_baseline_props \n", + "\n", + "[431 rows x 7 columns]\n" + ] + } + ], + "source": [ + "# Classes where baseline performs better\n", + "baseline_better = df[df[RES_BASELINE_DIFF] < 0]\n", + "print(\"\\nClasses where baseline performs better than our model:\\n\", baseline_better)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "gnn3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 0bd59565178634b01feff9e705ab774f008ae1d4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 4 Oct 2025 11:34:22 +0200 Subject: [PATCH 204/224] gatv2-constrainted --- chebai_graph/models/gat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py index 230d3c5..acd6fe8 100644 --- a/chebai_graph/models/gat.py +++ b/chebai_graph/models/gat.py @@ -28,6 +28,7 @@ def __init__(self, config: dict, **kwargs): super().__init__(config=config, **kwargs) self.heads = int(config["heads"]) self.v2 = bool(config["v2"]) + self.share_weights = bool(config.get("share_weights", False)) self.activation = ELU() # Instantiate ELU once for reuse. 
self.gat = GAT( in_channels=self.in_channels, @@ -39,6 +40,7 @@ def __init__(self, config: dict, **kwargs): heads=self.heads, v2=self.v2, act=self.activation, + share_weights=self.share_weights ) def forward(self, batch: dict) -> torch.Tensor: From 0653b4f2ea868481768d458887e22aecf197b3ea Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 4 Oct 2025 12:07:51 +0200 Subject: [PATCH 205/224] GATv2 amg pool --- chebai_graph/models/__init__.py | 18 ++----- chebai_graph/models/augmented.py | 88 +++----------------------------- chebai_graph/models/base.py | 2 +- chebai_graph/models/gat.py | 2 +- 4 files changed, 12 insertions(+), 98 deletions(-) diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index e2af0c3..f595cf6 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -1,29 +1,17 @@ from .augmented import ( GATAugNodePoolGraphPred, + GATGraphNodeFGNodePoolGraphPred, ResGatedAugNodePoolGraphPred, - ResGatedAugOnlyPoolGraphPred, - ResGatedFGNodeNoGraphNodeGraphPred, - ResGatedFGNodePoolGraphPred, - ResGatedFGOnlyPoolGraphPred, ResGatedGraphNodeFGNodePoolGraphPred, - ResGatedGraphNodeNoFGNodeGraphPred, - ResGatedGraphNodeOnlyPoolGraphPred, - ResGatedGraphNodePoolGraphPred, ) from .gat import GATGraphPred from .resgated import ResGatedGraphPred __all__ = [ - "GATGraphPred", "ResGatedGraphPred", - "ResGatedFGNodeNoGraphNodeGraphPred", "ResGatedAugNodePoolGraphPred", "ResGatedGraphNodeFGNodePoolGraphPred", - "ResGatedGraphNodePoolGraphPred", - "ResGatedGraphNodeNoFGNodeGraphPred", - "ResGatedFGNodePoolGraphPred", - "ResGatedAugOnlyPoolGraphPred", - "ResGatedGraphNodeOnlyPoolGraphPred", - "ResGatedFGOnlyPoolGraphPred", + "GATGraphPred", "GATAugNodePoolGraphPred", + "GATGraphNodeFGNodePoolGraphPred", ] diff --git a/chebai_graph/models/augmented.py b/chebai_graph/models/augmented.py index 22dff80..fdb5388 100644 --- a/chebai_graph/models/augmented.py +++ b/chebai_graph/models/augmented.py @@ -1,14 +1,4 @@ -from .base 
import ( - AugmentedNodePoolingNet, - AugmentedOnlyPoolingNet, - FGNodePoolingNet, - FGNodePoolingNoGraphNodeNet, - FGOnlyPoolingNet, - GraphNodeFGNodePoolingNet, - GraphNodeNoFGNodePoolingNet, - GraphNodeOnlyPoolingNet, - GraphNodePoolingNet, -) +from .base import AugmentedNodePoolingNet, GraphNodeFGNodePoolingNet from .gat import GATGraphPred from .resgated import ResGatedGraphPred @@ -16,7 +6,7 @@ class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred): """ Combines: - - AugmentedNodePoolingNet: Pools atom and augmented node embeddings with molecule attributes. + - AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes). - ResGatedGraphPred: Residual gated network for final graph prediction. """ @@ -26,94 +16,30 @@ class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred): class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred): """ Combines: - - AugmentedNodePoolingNet: Pools atom and augmented node embeddings with molecule attributes. + - AugmentedNodePoolingNet: Pools atom and augmented node embeddings (optionally with molecule attributes). - GATGraphPred: Graph attention network for final graph prediction. """ ... -class ResGatedGraphNodePoolGraphPred(GraphNodePoolingNet, ResGatedGraphPred): - """ - Combines: - - GraphNodePoolingNet: Pools atom and graph node embeddings with molecule attributes. - - ResGatedGraphPred: Residual gated network for final graph prediction. - """ - - ... - - -class ResGatedFGNodePoolGraphPred(FGNodePoolingNet, ResGatedGraphPred): - """ - Combines: - - FGNodePoolingNet: Pools functional group nodes and other nodes with molecule attributes. - - ResGatedGraphPred: Residual gated network for final graph prediction. - """ - - ... 
- - class ResGatedGraphNodeFGNodePoolGraphPred( GraphNodeFGNodePoolingNet, ResGatedGraphPred ): """ Combines: - - GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes with molecule attributes. - - ResGatedGraphPred: Residual gated network for final graph prediction. - """ - - ... - - -class ResGatedGraphNodeNoFGNodeGraphPred( - GraphNodeNoFGNodePoolingNet, ResGatedGraphPred -): - """ - Combines: - - GraphNodeNoFGNodePoolingNet: Pools atom and graph nodes, excluding functional groups. - - ResGatedGraphPred: Residual gated network for final graph prediction. - """ - - ... - - -class ResGatedFGNodeNoGraphNodeGraphPred( - FGNodePoolingNoGraphNodeNet, ResGatedGraphPred -): - """ - Combines: - - FGNodePoolingNoGraphNodeNet: Pools atom and functional group nodes, excluding graph nodes. - - ResGatedGraphPred: Residual gated network for final graph prediction. - """ - - ... - - -class ResGatedAugOnlyPoolGraphPred(AugmentedOnlyPoolingNet, ResGatedGraphPred): - """ - Combines: - - AugmentedOnlyPoolingNet: Pools only augmented nodes with molecule attributes. + - GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes). - ResGatedGraphPred: Residual gated network for final graph prediction. """ ... -class ResGatedGraphNodeOnlyPoolGraphPred(GraphNodeOnlyPoolingNet, ResGatedGraphPred): +class GATGraphNodeFGNodePoolGraphPred(GraphNodeFGNodePoolingNet, GATGraphPred): """ Combines: - - GraphNodeOnlyPoolingNet: Pools only graph nodes with molecule attributes. - - ResGatedGraphPred: Residual gated network for final graph prediction. - """ - - ... - - -class ResGatedFGOnlyPoolGraphPred(FGOnlyPoolingNet, ResGatedGraphPred): - """ - Combines: - - FGOnlyPoolingNet: Pools only functional group nodes with molecule attributes. - - ResGatedGraphPred: Residual gated network for final graph prediction. + - GraphNodeFGNodePoolingNet: Pools atom, functional group, and graph nodes (optionally with molecule attributes). 
+ - GATGraphPred: Graph attention network for final graph prediction. """ ... diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 1ddc784..952fd87 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -200,7 +200,7 @@ def _get_lin_seq_input_dim( Includes: - Atom embeddings - - Molecular attributes + - Molecular attributes (if any) - Augmented node embeddings Args: diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py index acd6fe8..45d5ed6 100644 --- a/chebai_graph/models/gat.py +++ b/chebai_graph/models/gat.py @@ -40,7 +40,7 @@ def __init__(self, config: dict, **kwargs): heads=self.heads, v2=self.v2, act=self.activation, - share_weights=self.share_weights + share_weights=self.share_weights, ) def forward(self, batch: dict) -> torch.Tensor: From c6b9ff2516f215c445e11a58ec5d3d7c63ac86e2 Mon Sep 17 00:00:00 2001 From: sifluegel Date: Mon, 6 Oct 2025 10:29:49 +0200 Subject: [PATCH 206/224] add chebi100 config --- configs/data/chebi100_graph_properties.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 configs/data/chebi100_graph_properties.yml diff --git a/configs/data/chebi100_graph_properties.yml b/configs/data/chebi100_graph_properties.yml new file mode 100644 index 0000000..a413968 --- /dev/null +++ b/configs/data/chebi100_graph_properties.yml @@ -0,0 +1,14 @@ +class_path: chebai_graph.preprocessing.datasets.ChEBI100GraphProperties +init_args: + properties: + - chebai_graph.preprocessing.properties.AtomType + - chebai_graph.preprocessing.properties.NumAtomBonds + - chebai_graph.preprocessing.properties.AtomCharge + - chebai_graph.preprocessing.properties.AtomAromaticity + - chebai_graph.preprocessing.properties.AtomHybridization + - chebai_graph.preprocessing.properties.AtomNumHs + - chebai_graph.preprocessing.properties.BondType + - chebai_graph.preprocessing.properties.BondInRing + - chebai_graph.preprocessing.properties.BondAromaticity + #- 
chebai_graph.preprocessing.properties.MoleculeNumRings + - chebai_graph.preprocessing.properties.RDKit2DNormalized From 7626dad7074dd76f5a3dd8ddcfe3dabb5a6a14a6 Mon Sep 17 00:00:00 2001 From: sifluegel Date: Mon, 6 Oct 2025 10:37:14 +0200 Subject: [PATCH 207/224] fix chebi100 config --- configs/data/chebi100_graph_properties.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/data/chebi100_graph_properties.yml b/configs/data/chebi100_graph_properties.yml index a413968..7e78711 100644 --- a/configs/data/chebi100_graph_properties.yml +++ b/configs/data/chebi100_graph_properties.yml @@ -1,4 +1,4 @@ -class_path: chebai_graph.preprocessing.datasets.ChEBI100GraphProperties +class_path: chebai_graph.preprocessing.datasets.chebi.ChEBI100GraphProperties init_args: properties: - chebai_graph.preprocessing.properties.AtomType From 3b38afada251aba8613c14c41e4fe5063d74585d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 6 Oct 2025 13:25:27 +0200 Subject: [PATCH 208/224] static gni no mol props --- .../preprocessing/reader/static_gni.py | 18 +++++++++--------- configs/data/chebi50_static_gni.yml | 2 -- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index 68a4e8e..c045942 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -17,15 +17,15 @@ def __init__( self, num_node_properties: int, num_bond_properties: int, - num_molecule_properties: int, - distribution: str, + # num_molecule_properties: int, + distribution: str = "normal", *args, **kwargs, ): super().__init__(*args, **kwargs) self.num_node_properties = num_node_properties self.num_bond_properties = num_bond_properties - self.num_molecule_properties = num_molecule_properties + # self.num_molecule_properties = num_molecule_properties assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"] 
self.distribution = distribution @@ -44,30 +44,30 @@ def _read_data(self, raw_data): random_edge_attr = torch.empty( data.edge_index.shape[1], self.num_bond_properties ) - random_molecule_properties = torch.empty(1, self.num_molecule_properties) + # random_molecule_properties = torch.empty(1, self.num_molecule_properties) if self.distribution == "normal": torch.nn.init.normal_(random_x) torch.nn.init.normal_(random_edge_attr) - torch.nn.init.normal_(random_molecule_properties) + # torch.nn.init.normal_(random_molecule_properties) elif self.distribution == "uniform": torch.nn.init.uniform_(random_x, a=-1.0, b=1.0) torch.nn.init.uniform_(random_edge_attr, a=-1.0, b=1.0) - torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0) + # torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0) elif self.distribution == "xavier_normal": torch.nn.init.xavier_normal_(random_x) torch.nn.init.xavier_normal_(random_edge_attr) - torch.nn.init.xavier_normal_(random_molecule_properties) + # torch.nn.init.xavier_normal_(random_molecule_properties) elif self.distribution == "xavier_uniform": torch.nn.init.xavier_uniform_(random_x) torch.nn.init.xavier_uniform_(random_edge_attr) - torch.nn.init.xavier_uniform_(random_molecule_properties) + # torch.nn.init.xavier_uniform_(random_molecule_properties) else: raise ValueError("Unknown distribution type") data.x = random_x data.edge_attr = random_edge_attr - data.molecule_attr = random_molecule_properties + # data.molecule_attr = random_molecule_properties return data def read_property(self, *args, **kwargs) -> Exception: diff --git a/configs/data/chebi50_static_gni.yml b/configs/data/chebi50_static_gni.yml index 1509802..5130005 100644 --- a/configs/data/chebi50_static_gni.yml +++ b/configs/data/chebi50_static_gni.yml @@ -3,5 +3,3 @@ init_args: reader_kwargs: num_node_properties: 158 num_bond_properties: 7 - num_molecule_properties: 200 - distribution: normal From 2757ed30f6809d3462cda3811b272fbc20ab275f Mon Sep 17 
00:00:00 2001 From: aditya0by0 Date: Mon, 6 Oct 2025 16:29:55 +0200 Subject: [PATCH 209/224] fix gat v2 share weights issue --- chebai_graph/models/gat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chebai_graph/models/gat.py b/chebai_graph/models/gat.py index 45d5ed6..13dcabc 100644 --- a/chebai_graph/models/gat.py +++ b/chebai_graph/models/gat.py @@ -28,7 +28,9 @@ def __init__(self, config: dict, **kwargs): super().__init__(config=config, **kwargs) self.heads = int(config["heads"]) self.v2 = bool(config["v2"]) - self.share_weights = bool(config.get("share_weights", False)) + local_kwargs = {} + if self.v2: + local_kwargs["share_weights"] = bool(config.get("share_weights", False)) self.activation = ELU() # Instantiate ELU once for reuse. self.gat = GAT( in_channels=self.in_channels, @@ -40,7 +42,7 @@ def __init__(self, config: dict, **kwargs): heads=self.heads, v2=self.v2, act=self.activation, - share_weights=self.share_weights, + **local_kwargs, ) def forward(self, batch: dict) -> torch.Tensor: From a0d6ea7701081fd6025f671f40fe5298a4ff65e5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 6 Oct 2025 17:13:20 +0200 Subject: [PATCH 210/224] revert mol props --- chebai_graph/preprocessing/reader/static_gni.py | 16 ++++++++-------- configs/data/chebi50_static_gni.yml | 1 + 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index c045942..c77e429 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -17,7 +17,7 @@ def __init__( self, num_node_properties: int, num_bond_properties: int, - # num_molecule_properties: int, + num_molecule_properties: int, distribution: str = "normal", *args, **kwargs, @@ -25,7 +25,7 @@ def __init__( super().__init__(*args, **kwargs) self.num_node_properties = num_node_properties self.num_bond_properties = num_bond_properties - # 
self.num_molecule_properties = num_molecule_properties + self.num_molecule_properties = num_molecule_properties assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"] self.distribution = distribution @@ -44,30 +44,30 @@ def _read_data(self, raw_data): random_edge_attr = torch.empty( data.edge_index.shape[1], self.num_bond_properties ) - # random_molecule_properties = torch.empty(1, self.num_molecule_properties) + random_molecule_properties = torch.empty(1, self.num_molecule_properties) if self.distribution == "normal": torch.nn.init.normal_(random_x) torch.nn.init.normal_(random_edge_attr) - # torch.nn.init.normal_(random_molecule_properties) + torch.nn.init.normal_(random_molecule_properties) elif self.distribution == "uniform": torch.nn.init.uniform_(random_x, a=-1.0, b=1.0) torch.nn.init.uniform_(random_edge_attr, a=-1.0, b=1.0) - # torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0) + torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0) elif self.distribution == "xavier_normal": torch.nn.init.xavier_normal_(random_x) torch.nn.init.xavier_normal_(random_edge_attr) - # torch.nn.init.xavier_normal_(random_molecule_properties) + torch.nn.init.xavier_normal_(random_molecule_properties) elif self.distribution == "xavier_uniform": torch.nn.init.xavier_uniform_(random_x) torch.nn.init.xavier_uniform_(random_edge_attr) - # torch.nn.init.xavier_uniform_(random_molecule_properties) + torch.nn.init.xavier_uniform_(random_molecule_properties) else: raise ValueError("Unknown distribution type") data.x = random_x data.edge_attr = random_edge_attr - # data.molecule_attr = random_molecule_properties + data.molecule_attr = random_molecule_properties return data def read_property(self, *args, **kwargs) -> Exception: diff --git a/configs/data/chebi50_static_gni.yml b/configs/data/chebi50_static_gni.yml index 5130005..12096cb 100644 --- a/configs/data/chebi50_static_gni.yml +++ b/configs/data/chebi50_static_gni.yml @@ -3,3 +3,4 @@ 
init_args: reader_kwargs: num_node_properties: 158 num_bond_properties: 7 + num_molecule_properties: 0 From 79fc5007c7d042a16df797d7691dc8cf9bb9475b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 6 Oct 2025 19:40:44 +0200 Subject: [PATCH 211/224] dynamic gni --- chebai_graph/models/__init__.py | 2 + chebai_graph/models/dynamic_gni.py | 89 +++++++++++++++++++ chebai_graph/preprocessing/datasets/chebi.py | 4 +- chebai_graph/preprocessing/reader/__init__.py | 4 +- .../preprocessing/reader/static_gni.py | 36 ++++---- configs/model/resgated_dynamic_gni.yml | 13 +++ 6 files changed, 125 insertions(+), 23 deletions(-) create mode 100644 chebai_graph/models/dynamic_gni.py create mode 100644 configs/model/resgated_dynamic_gni.yml diff --git a/chebai_graph/models/__init__.py b/chebai_graph/models/__init__.py index f595cf6..9e20b2d 100644 --- a/chebai_graph/models/__init__.py +++ b/chebai_graph/models/__init__.py @@ -4,6 +4,7 @@ ResGatedAugNodePoolGraphPred, ResGatedGraphNodeFGNodePoolGraphPred, ) +from .dynamic_gni import ResGatedDynamicGNIGraphPred from .gat import GATGraphPred from .resgated import ResGatedGraphPred @@ -14,4 +15,5 @@ "GATGraphPred", "GATAugNodePoolGraphPred", "GATGraphNodeFGNodePoolGraphPred", + "ResGatedDynamicGNIGraphPred", ] diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py new file mode 100644 index 0000000..26b9a88 --- /dev/null +++ b/chebai_graph/models/dynamic_gni.py @@ -0,0 +1,89 @@ +from typing import Any + +import torch +from torch import Tensor +from torch.nn import ELU +from torch_geometric.data import Data as GraphData +from torch_geometric.nn.models.basic_gnn import BasicGNN + +from chebai_graph.preprocessing.reader import RandomFeatureInitializationReader + +from .base import GraphModelBase, GraphNetWrapper +from .resgated import ResGatedModel + + +class ResGatedDynamicGNI(GraphModelBase): + """ + Base model class for applying ResGatedGraphConv layers to graph-structured data + with dynamic 
initialization of features for nodes and edges. + + Args: + config (dict): Configuration dictionary containing model hyperparameters. + **kwargs: Additional keyword arguments for parent class. + """ + + def __init__(self, config: dict[str, Any], **kwargs: Any): + super().__init__(config=config, **kwargs) + self.activation = ELU() # Instantiate ELU once for reuse. + distribution = config.get("distribution", "normal") + assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"] + self.distribution = distribution + + self.resgated: BasicGNN = ResGatedModel( + in_channels=self.in_channels, + hidden_channels=self.hidden_channels, + out_channels=self.out_channels, + num_layers=self.num_layers, + edge_dim=self.edge_dim, + act=self.activation, + ) + + def forward(self, batch: dict[str, Any]) -> Tensor: + """ + Forward pass of the model. + + Args: + batch (dict): A batch containing graph input features under the key "features". + + Returns: + Tensor: The output node-level embeddings after the final activation. + """ + graph_data = batch["features"][0] + assert isinstance(graph_data, GraphData), "Expected GraphData instance" + + random_x = torch.empty(graph_data.x.shape[0], graph_data.x.shape[1]) + RandomFeatureInitializationReader.random_gni(random_x, self.distribution) + random_edge_attr = torch.empty( + graph_data.edge_attr.shape[0], graph_data.edge_attr.shape[1] + ) + RandomFeatureInitializationReader.random_gni( + random_edge_attr, self.distribution + ) + + out = self.resgated( + x=graph_data.x.float(), + edge_index=graph_data.edge_index.long(), + edge_attr=graph_data.edge_attr, + ) + + return self.activation(out) + + +class ResGatedDynamicGNIGraphPred(GraphNetWrapper): + """ + Wrapper for graph-level prediction using ResGatedDynamicGNI. + + This class instantiates the core GNN model using the provided config. + """ + + def _get_gnn(self, config: dict[str, Any]) -> ResGatedDynamicGNI: + """ + Returns the core ResGated GNN model. 
+ + Args: + config (dict): Configuration dictionary for the GNN model. + + Returns: + ResGatedDynamicGNI: The core graph convolutional network. + """ + return ResGatedDynamicGNI(config=config) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 06a8bac..27e27e7 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -36,7 +36,7 @@ GN_WithAtoms_FG_WithAtoms_NoFGE, GraphPropertyReader, GraphReader, - RandomNodeInitializationReader, + RandomFeatureInitializationReader, ) from .utils import resolve_property @@ -518,7 +518,7 @@ def _merge_props_into_base( class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50): - READER = RandomNodeInitializationReader + READER = RandomFeatureInitializationReader def _setup_properties(self): ... diff --git a/chebai_graph/preprocessing/reader/__init__.py b/chebai_graph/preprocessing/reader/__init__.py index ee3fce0..3bc71ed 100644 --- a/chebai_graph/preprocessing/reader/__init__.py +++ b/chebai_graph/preprocessing/reader/__init__.py @@ -10,7 +10,7 @@ GN_WithAtoms_FG_WithAtoms_NoFGE, ) from .reader import GraphPropertyReader, GraphReader -from .static_gni import RandomNodeInitializationReader +from .static_gni import RandomFeatureInitializationReader __all__ = [ "GraphReader", @@ -20,7 +20,7 @@ "AtomFGReader_NoFGEdges_WithGraphNode", "AtomFGReader_WithFGEdges_NoGraphNode", "AtomFGReader_WithFGEdges_WithGraphNode", - "RandomNodeInitializationReader", + "RandomFeatureInitializationReader", "GN_WithAtoms_FG_WithAtoms_FGE", "GN_WithAtoms_FG_WithAtoms_NoFGE", "GN_WithAllNodes_FG_WithAtoms_FGE", diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index c77e429..0084b9d 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -12,7 +12,7 @@ from .reader import GraphPropertyReader -class 
RandomNodeInitializationReader(GraphPropertyReader): +class RandomFeatureInitializationReader(GraphPropertyReader): def __init__( self, num_node_properties: int, @@ -46,24 +46,9 @@ def _read_data(self, raw_data): ) random_molecule_properties = torch.empty(1, self.num_molecule_properties) - if self.distribution == "normal": - torch.nn.init.normal_(random_x) - torch.nn.init.normal_(random_edge_attr) - torch.nn.init.normal_(random_molecule_properties) - elif self.distribution == "uniform": - torch.nn.init.uniform_(random_x, a=-1.0, b=1.0) - torch.nn.init.uniform_(random_edge_attr, a=-1.0, b=1.0) - torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0) - elif self.distribution == "xavier_normal": - torch.nn.init.xavier_normal_(random_x) - torch.nn.init.xavier_normal_(random_edge_attr) - torch.nn.init.xavier_normal_(random_molecule_properties) - elif self.distribution == "xavier_uniform": - torch.nn.init.xavier_uniform_(random_x) - torch.nn.init.xavier_uniform_(random_edge_attr) - torch.nn.init.xavier_uniform_(random_molecule_properties) - else: - raise ValueError("Unknown distribution type") + self.random_gni(random_x, self.distribution) + self.random_gni(random_edge_attr, self.distribution) + self.random_gni(random_molecule_properties, self.distribution) data.x = random_x data.edge_attr = random_edge_attr @@ -73,3 +58,16 @@ def _read_data(self, raw_data): def read_property(self, *args, **kwargs) -> Exception: """This reader does not support reading specific properties.""" raise NotImplementedError("This reader only performs random initialization.") + + @staticmethod + def random_gni(tensor: torch.Tensor, distribution: str) -> None: + if distribution == "normal": + torch.nn.init.normal_(tensor) + elif distribution == "uniform": + torch.nn.init.uniform_(tensor, a=-1.0, b=1.0) + elif distribution == "xavier_normal": + torch.nn.init.xavier_normal_(tensor) + elif distribution == "xavier_uniform": + torch.nn.init.xavier_uniform_(tensor) + else: + raise 
ValueError("Unknown distribution type") diff --git a/configs/model/resgated_dynamic_gni.yml b/configs/model/resgated_dynamic_gni.yml new file mode 100644 index 0000000..4749795 --- /dev/null +++ b/configs/model/resgated_dynamic_gni.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedDynamicGNIGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 158 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 7 # number of bond properties + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 From d2fe1d2f7cdfecc0411d2c528e0a6a9eb7972d64 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 6 Oct 2025 23:01:44 +0200 Subject: [PATCH 212/224] fix reader none error --- chebai_graph/preprocessing/reader/static_gni.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index 0084b9d..347f6d5 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -40,9 +40,12 @@ def name(self) -> str: def _read_data(self, raw_data): data: GeomData = super()._read_data(raw_data) + if data is None: + return None + random_x = torch.empty(data.x.shape[0], self.num_node_properties) random_edge_attr = torch.empty( - data.edge_index.shape[1], self.num_bond_properties + data.edge_attr.shape[0], self.num_bond_properties ) random_molecule_properties = torch.empty(1, self.num_molecule_properties) From 69cf0a59d57644a05d0ddf558687498becdab7c3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 6 Oct 2025 23:17:05 +0200 Subject: [PATCH 213/224] dynamic gni fix --- chebai_graph/models/dynamic_gni.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index 26b9a88..5e7ffe6 100644 --- a/chebai_graph/models/dynamic_gni.py +++ 
b/chebai_graph/models/dynamic_gni.py @@ -61,9 +61,9 @@ def forward(self, batch: dict[str, Any]) -> Tensor: ) out = self.resgated( - x=graph_data.x.float(), + x=random_x.float(), edge_index=graph_data.edge_index.long(), - edge_attr=graph_data.edge_attr, + edge_attr=random_edge_attr.float(), ) return self.activation(out) From 09e9227d8709b4532bcc71b257a284b21827c269 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 7 Oct 2025 10:58:52 +0200 Subject: [PATCH 214/224] allow not setting n_molecule_properties --- chebai_graph/models/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 1ddc784..58fff35 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Optional import torch from chebai.models.base import ChebaiBaseNet @@ -73,7 +74,7 @@ class GraphNetWrapper(GraphBaseNet, ABC): """ def __init__( - self, config: dict, n_linear_layers: int, n_molecule_properties: int, **kwargs + self, config: dict, n_linear_layers: int, n_molecule_properties: Optional[int] = 0, **kwargs ) -> None: """ Initialize the GNN and linear layers. 
@@ -90,7 +91,7 @@ def __init__( self.activation = torch.nn.ELU self.lin_input_dim = self._get_lin_seq_input_dim( gnn_out_dim=gnn_out_dim, - n_molecule_properties=n_molecule_properties, + n_molecule_properties=n_molecule_properties if n_molecule_properties is not None else 0, ) lin_hidden_dim = kwargs.get("lin_hidden_dim", gnn_out_dim) From e304f150dd41780ce551edc8ef0d730bc1414ea9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 7 Oct 2025 13:21:15 +0200 Subject: [PATCH 215/224] add configs for aug --- .gitignore | 5 +++++ configs/data/chebi50_aug_all_props.yml | 20 -------------------- configs/data/chebi50_aug_props_only.yml | 11 ----------- configs/model/gat_aug_aapool.yml | 15 +++++++++++++++ configs/model/gat_aug_amgpool.yml | 15 +++++++++++++++ configs/model/res_aug_aapool.yml | 13 +++++++++++++ configs/model/res_aug_amgpool.yml | 13 +++++++++++++ 7 files changed, 61 insertions(+), 31 deletions(-) delete mode 100644 configs/data/chebi50_aug_all_props.yml delete mode 100644 configs/data/chebi50_aug_props_only.yml create mode 100644 configs/model/gat_aug_aapool.yml create mode 100644 configs/model/gat_aug_amgpool.yml create mode 100644 configs/model/res_aug_aapool.yml create mode 100644 configs/model/res_aug_amgpool.yml diff --git a/.gitignore b/.gitignore index 3f0c0ab..d01eecf 100644 --- a/.gitignore +++ b/.gitignore @@ -169,3 +169,8 @@ cython_debug/ electra_pretrained.ckpt .isort.cfg /.vscode + +*.err +*.out +*.sh +*.ckpt diff --git a/configs/data/chebi50_aug_all_props.yml b/configs/data/chebi50_aug_all_props.yml deleted file mode 100644 index 47e75ad..0000000 --- a/configs/data/chebi50_aug_all_props.yml +++ /dev/null @@ -1,20 +0,0 @@ -class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp -init_args: - properties: - # Atom properties - - chebai_graph.preprocessing.properties.AtomNodeLevel - - chebai_graph.preprocessing.properties.AugAtomType - - chebai_graph.preprocessing.properties.AugNumAtomBonds - - 
chebai_graph.preprocessing.properties.AugAtomCharge - - chebai_graph.preprocessing.properties.AugAtomAromaticity - - chebai_graph.preprocessing.properties.AugAtomHybridization - - chebai_graph.preprocessing.properties.AugAtomNumHs - - chebai_graph.preprocessing.properties.AtomFunctionalGroup - - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG - - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG - - chebai_graph.preprocessing.properties.IsFGAlkyl - # Bond properties - - chebai_graph.preprocessing.properties.BondLevel - - chebai_graph.preprocessing.properties.AugBondType - - chebai_graph.preprocessing.properties.AugBondInRing - - chebai_graph.preprocessing.properties.AugBondAromaticity diff --git a/configs/data/chebi50_aug_props_only.yml b/configs/data/chebi50_aug_props_only.yml deleted file mode 100644 index d81d303..0000000 --- a/configs/data/chebi50_aug_props_only.yml +++ /dev/null @@ -1,11 +0,0 @@ -class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp -init_args: - properties: - # Atom properties - - chebai_graph.preprocessing.properties.AtomFunctionalGroup - - chebai_graph.preprocessing.properties.AtomNodeLevel - - chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG - - chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG - - chebai_graph.preprocessing.properties.IsFGAlkyl - # Bond properties - - chebai_graph.preprocessing.properties.BondLevel diff --git a/configs/model/gat_aug_aapool.yml b/configs/model/gat_aug_aapool.yml new file mode 100644 index 0000000..05ef4e8 --- /dev/null +++ b/configs/model/gat_aug_aapool.yml @@ -0,0 +1,15 @@ +class_path: chebai_graph.models.GATAugNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel 
not given) + v2: False # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/gat_aug_amgpool.yml b/configs/model/gat_aug_amgpool.yml new file mode 100644 index 0000000..b5ce89d --- /dev/null +++ b/configs/model/gat_aug_amgpool.yml @@ -0,0 +1,15 @@ +class_path: chebai_graph.models.GATGraphNodeFGNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + heads: 8 # the number of heads should be divisible by output channels (hidden channels if output channel not given) + v2: True # set True to use `torch_geometric.nn.conv.GATv2Conv` convolution layers, default is GATConv + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/res_aug_aapool.yml b/configs/model/res_aug_aapool.yml new file mode 100644 index 0000000..9e364f9 --- /dev/null +++ b/configs/model/res_aug_aapool.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedAugNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 diff --git a/configs/model/res_aug_amgpool.yml b/configs/model/res_aug_amgpool.yml new file mode 100644 index 0000000..2aba5ea --- /dev/null +++ b/configs/model/res_aug_amgpool.yml @@ -0,0 +1,13 @@ +class_path: chebai_graph.models.ResGatedGraphNodeFGNodePoolGraphPred +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + in_channels: 203 # number of node/atom properties + hidden_channels: 256 + out_channels: 512 + num_layers: 4 + edge_dim: 11 # number of bond properties + dropout: 0 + n_molecule_properties: 0 + n_linear_layers: 1 From 
bc3981bead5b2ec1444d9f55180a39af17e8356e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 7 Oct 2025 18:32:55 +0200 Subject: [PATCH 216/224] set device for random initiated tensors --- chebai_graph/models/dynamic_gni.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index 5e7ffe6..59f5f0d 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -37,6 +37,7 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): edge_dim=self.edge_dim, act=self.activation, ) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def forward(self, batch: dict[str, Any]) -> Tensor: """ @@ -51,10 +52,15 @@ def forward(self, batch: dict[str, Any]) -> Tensor: graph_data = batch["features"][0] assert isinstance(graph_data, GraphData), "Expected GraphData instance" - random_x = torch.empty(graph_data.x.shape[0], graph_data.x.shape[1]) + random_x = torch.empty( + graph_data.x.shape[0], graph_data.x.shape[1], device=self.device + ) RandomFeatureInitializationReader.random_gni(random_x, self.distribution) + random_edge_attr = torch.empty( - graph_data.edge_attr.shape[0], graph_data.edge_attr.shape[1] + graph_data.edge_attr.shape[0], + graph_data.edge_attr.shape[1], + device=self.device, ) RandomFeatureInitializationReader.random_gni( random_edge_attr, self.distribution From 0b5f650fa241334725e65ea01729e66ac7bc511a Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 7 Oct 2025 19:34:42 +0200 Subject: [PATCH 217/224] add padding with zeros or randomness for node and edges --- chebai_graph/preprocessing/datasets/chebi.py | 98 ++++++++++++++++--- .../preprocessing/reader/static_gni.py | 4 +- 2 files changed, 90 insertions(+), 12 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 27e27e7..cdb406e 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ 
b/chebai_graph/preprocessing/datasets/chebi.py @@ -186,16 +186,50 @@ def _after_setup(self, **kwargs) -> None: class GraphPropertiesMixIn(DataPropertiesSetter, ABC): def __init__( - self, properties=None, transform=None, zero_pad_atom: int = None, **kwargs + self, + properties=None, + transform=None, + zero_pad_node: int = None, + zero_pad_edge: int = None, + random_pad_node: int = None, + random_pad_edge: int = None, + distribution: str = "normal", + **kwargs, ): super().__init__(properties, transform, **kwargs) - self.zero_pad_atom = int(zero_pad_atom) if zero_pad_atom is not None else None - if self.zero_pad_atom: + self.zero_pad_node = int(zero_pad_node) if zero_pad_node else None + if self.zero_pad_node: + print( + f"[Info] Node-level features will be zero-padded with " + f"{self.zero_pad_node} additional dimensions." + ) + + self.zero_pad_edge = int(zero_pad_edge) if zero_pad_edge else None + if self.zero_pad_edge: print( - f"[Info] Atom-level features will be zero-padded with " - f"{self.zero_pad_atom} additional dimensions." + f"[Info] Edge-level features will be zero-padded with " + f"{self.zero_pad_edge} additional dimensions." ) + self.random_pad_edge = int(random_pad_edge) if random_pad_edge else None + self.random_pad_node = int(random_pad_node) if random_pad_node else None + if self.random_pad_node or self.random_pad_edge: + assert ( + distribution is not None + and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS + ), "When using random padding, a valid distribution must be specified." + self.distribution = distribution + if self.random_pad_node: + print( + f"[Info] Node-level features will be padded with " + f"{self.random_pad_node} additional dimensions initialized from {self.distribution} distribution." + ) + if self.random_pad_edge: + print( + f"[Info] Edge-level features will be padded with " + f"{self.random_pad_edge} additional dimensions initialized from {self.distribution} distribution." 
+ ) + if self.properties: print( f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}" @@ -242,8 +276,24 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData: else: raise TypeError(f"Unsupported property type: {type(property).__name__}") - if self.zero_pad_atom is not None: - x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_atom))], dim=1) + if self.zero_pad_node: + x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_node))], dim=1) + + if self.zero_pad_edge: + edge_attr = torch.cat( + [edge_attr, torch.zeros((edge_attr.shape[0], self.zero_pad_edge))], + dim=1, + ) + + if self.random_pad_node: + random_pad = torch.empty((x.shape[0], self.random_pad_node)) + RandomFeatureInitializationReader.random_gni(random_pad, self.distribution) + x = torch.cat([x, random_pad], dim=1) + + if self.random_pad_edge: + random_pad = torch.empty((edge_attr.shape[0], self.random_pad_edge)) + RandomFeatureInitializationReader.random_gni(random_pad, self.distribution) + edge_attr = torch.cat([edge_attr, random_pad], dim=1) return GeomData( x=x, @@ -291,18 +341,44 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: prop_lengths = [ (prop.name, prop.encoder.get_encoding_length()) for prop in self.properties ] + + # -------------------------- Count total node properties n_node_properties = sum( p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty) ) - if self.zero_pad_atom: - n_node_properties += self.zero_pad_atom + + in_channels_str = f"in_channels: {n_node_properties}" + if self.zero_pad_node: + n_node_properties += self.zero_pad_node + in_channels_str += f"(with {self.zero_pad_node} padded zeros)" + + if self.random_pad_node: + n_node_properties += self.random_pad_node + in_channels_str += f"(with {self.random_pad_node} random padded values from {self.distribution} distribution)" + + # -------------------------- Count total edge properties + n_edge_properties = sum( + 
p.encoder.get_encoding_length() + for p in self.properties + if isinstance(p, BondProperty) + ) + edge_dim_str = f"edge_dim: {n_edge_properties}" + + if self.zero_pad_edge: + n_edge_properties += self.zero_pad_edge + edge_dim_str += f"(with {self.zero_pad_edge} padded zeros)" + + if self.random_pad_edge: + n_edge_properties += self.random_pad_edge + edge_dim_str += f"(with {self.random_pad_edge} random padded values from {self.distribution} distribution)" + rank_zero_info( f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" f"Use following values for given parameters for model configuration: \n\t" - f"in_channels: {n_node_properties} (with {self.zero_pad_atom} padded zeros) , " - f"edge_dim: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, " + f"{in_channels_str}, " + f"{edge_dim_str}, " f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}" ) diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index 347f6d5..398c08f 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -13,6 +13,8 @@ class RandomFeatureInitializationReader(GraphPropertyReader): + DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform"] + def __init__( self, num_node_properties: int, @@ -26,7 +28,7 @@ def __init__( self.num_node_properties = num_node_properties self.num_bond_properties = num_bond_properties self.num_molecule_properties = num_molecule_properties - assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"] + assert distribution in self.DISTRIBUTIONS self.distribution = distribution def name(self) -> str: From 9c567534e4ac738a6d453e8412664e516ce0dc73 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 7 Oct 2025 19:50:55 +0200 Subject: [PATCH 218/224] for dynamic randomness set 
option for padding existing features with random features --- chebai_graph/models/dynamic_gni.py | 86 ++++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 17 deletions(-) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index 59f5f0d..7811e04 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -25,10 +25,34 @@ class ResGatedDynamicGNI(GraphModelBase): def __init__(self, config: dict[str, Any], **kwargs: Any): super().__init__(config=config, **kwargs) self.activation = ELU() # Instantiate ELU once for reuse. + distribution = config.get("distribution", "normal") - assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"] + assert distribution in RandomFeatureInitializationReader.DISTRIBUTIONS, ( + f"Unsupported distribution: {distribution}. " + f"Choose from {RandomFeatureInitializationReader.DISTRIBUTIONS}." + ) self.distribution = distribution + self.complete_randomness = config.get("complete_randomness", True) + + if not self.complete_randomness: + assert ( + "random_pad_node" in config or "random_pad_edge" in config + ), "Missing 'random_pad_node' or 'random_pad_edge' in config when complete_randomness is False" + self.random_pad_node = ( + int(config["random_pad_node"]) + if config.get("random_pad_node") is not None + else None + ) + self.random_pad_edge = ( + int(config["random_pad_edge"]) + if config.get("random_pad_edge") is not None + else None + ) + assert ( + self.random_pad_node > 0 or self.random_pad_edge > 0 + ), "'random_pad_node' or 'random_pad_edge' must be positive integers" + self.resgated: BasicGNN = ResGatedModel( in_channels=self.in_channels, hidden_channels=self.hidden_channels, @@ -52,24 +76,52 @@ def forward(self, batch: dict[str, Any]) -> Tensor: graph_data = batch["features"][0] assert isinstance(graph_data, GraphData), "Expected GraphData instance" - random_x = torch.empty( - graph_data.x.shape[0], graph_data.x.shape[1], 
device=self.device - ) - RandomFeatureInitializationReader.random_gni(random_x, self.distribution) - - random_edge_attr = torch.empty( - graph_data.edge_attr.shape[0], - graph_data.edge_attr.shape[1], - device=self.device, - ) - RandomFeatureInitializationReader.random_gni( - random_edge_attr, self.distribution - ) - + new_x = None + new_edge_attr = None + if self.complete_randomness: + new_x = torch.empty( + graph_data.x.shape[0], graph_data.x.shape[1], device=self.device + ) + RandomFeatureInitializationReader.random_gni(new_x, self.distribution) + + new_edge_attr = torch.empty( + graph_data.edge_attr.shape[0], + graph_data.edge_attr.shape[1], + device=self.device, + ) + RandomFeatureInitializationReader.random_gni( + new_edge_attr, self.distribution + ) + else: + if self.random_pad_node is not None: + pad_node = torch.empty( + graph_data.x.shape[0], + self.random_pad_node, + device=self.device, + ) + RandomFeatureInitializationReader.random_gni( + pad_node, self.distribution + ) + new_x = torch.cat((graph_data.x, pad_node), dim=1) + + if self.random_pad_edge is not None: + pad_edge = torch.empty( + graph_data.edge_attr.shape[0], + self.random_pad_edge, + device=self.device, + ) + RandomFeatureInitializationReader.random_gni( + pad_edge, self.distribution + ) + new_edge_attr = torch.cat((graph_data.edge_attr, pad_edge), dim=1) + + assert ( + new_x is not None and new_edge_attr is not None + ), "Feature initialization failed" out = self.resgated( - x=random_x.float(), + x=new_x.float(), edge_index=graph_data.edge_index.long(), - edge_attr=random_edge_attr.float(), + edge_attr=new_edge_attr.float(), ) return self.activation(out) From a5070e8087cd34ca4f0165ff95591fc1e64163d3 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 7 Oct 2025 23:50:43 +0200 Subject: [PATCH 219/224] if padding applied, create separate data.pt file --- chebai_graph/preprocessing/datasets/chebi.py | 46 ++++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git 
a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index cdb406e..798b3ba 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -349,14 +349,16 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: if isinstance(p, AtomProperty) ) - in_channels_str = f"in_channels: {n_node_properties}" + in_channels_str = "" if self.zero_pad_node: n_node_properties += self.zero_pad_node - in_channels_str += f"(with {self.zero_pad_node} padded zeros)" + in_channels_str += f" (with {self.zero_pad_node} padded zeros)" if self.random_pad_node: n_node_properties += self.random_pad_node - in_channels_str += f"(with {self.random_pad_node} random padded values from {self.distribution} distribution)" + in_channels_str += f" (with {self.random_pad_node} random padded values from {self.distribution} distribution)" + + in_channels_str = f"in_channels: {n_node_properties}" + in_channels_str # -------------------------- Count total edge properties n_edge_properties = sum( @@ -364,26 +366,54 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: for p in self.properties if isinstance(p, BondProperty) ) - edge_dim_str = f"edge_dim: {n_edge_properties}" + edge_dim_str = "" if self.zero_pad_edge: n_edge_properties += self.zero_pad_edge - edge_dim_str += f"(with {self.zero_pad_edge} padded zeros)" + edge_dim_str += f" (with {self.zero_pad_edge} padded zeros)" if self.random_pad_edge: n_edge_properties += self.random_pad_edge - edge_dim_str += f"(with {self.random_pad_edge} random padded values from {self.distribution} distribution)" + edge_dim_str += f" (with {self.random_pad_edge} random padded values from {self.distribution} distribution)" + + edge_dim_str = f"edge_dim: {n_edge_properties}" + edge_dim_str rank_zero_info( f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n" f"Use following values for given parameters for model 
configuration: \n\t" - f"{in_channels_str}, " - f"{edge_dim_str}, " + f"{in_channels_str} \n\t" + f"{edge_dim_str} \n\t" f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}" ) return base_df[base_data[0].keys()].to_dict("records") + @property + def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary for the processed and tokenized data files. + + Returns: + dict: A dictionary mapping dataset keys to their respective file names. + For example, {"data": "data.pt"}. + """ + if self.n_token_limit is not None: + return {"data": f"data_maxlen{self.n_token_limit}.pt"} + + data_pt_filename = "data" + if self.zero_pad_node: + data_pt_filename += f"_zpn{self.zero_pad_node}" + if self.zero_pad_edge: + data_pt_filename += f"_zpe{self.zero_pad_edge}" + if self.random_pad_node: + data_pt_filename += f"_rpn{self.random_pad_node}" + if self.random_pad_edge: + data_pt_filename += f"_rpe{self.random_pad_edge}" + if self.random_pad_node or self.random_pad_edge: + data_pt_filename += f"_D{self.distribution}" + + return {"data": data_pt_filename + ".pt"} + class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): def __init__(self, properties=None, transform=None, **kwargs): From 2d8d116a3ced34d87ac3bb9ef423b6aee4902127 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Thu, 9 Oct 2025 10:37:38 +0200 Subject: [PATCH 220/224] add ChEBI100 aug class --- chebai_graph/preprocessing/datasets/chebi.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 798b3ba..9833266 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -779,3 +779,7 @@ class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver5 class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50): READER = AtomFGReader_WithFGEdges_WithGraphNode + + +class 
ChEBI100_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver100): + READER = AtomFGReader_WithFGEdges_WithGraphNode From 04441312b0f930e6995d48126903f8672adfb5c5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 10 Oct 2025 13:10:10 +0200 Subject: [PATCH 221/224] minor changes for gni --- chebai_graph/models/dynamic_gni.py | 20 ++++++++++++++- chebai_graph/preprocessing/datasets/chebi.py | 26 -------------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index 7811e04..123ac03 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -33,7 +33,11 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): ) self.distribution = distribution - self.complete_randomness = config.get("complete_randomness", True) + self.complete_randomness = ( + str(config.get("complete_randomness", "True")).lower() == "true" + ) + + print("Using complete randomness: ", self.complete_randomness) if not self.complete_randomness: assert ( @@ -44,11 +48,25 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): if config.get("random_pad_node") is not None else None ) + if self.random_pad_node is not None: + print( + f"[Info] Node features will be padded with {self.random_pad_node} " + f"new set of random features from distribution {self.distribution} " + f"in each forward pass." + ) + self.random_pad_edge = ( int(config["random_pad_edge"]) if config.get("random_pad_edge") is not None else None ) + if self.random_pad_edge is not None: + print( + f"[Info] Edge features will be padded with {self.random_pad_edge} " + f"new set of random features from distribution {self.distribution} " + f"in each forward pass." 
+ ) + assert ( self.random_pad_node > 0 or self.random_pad_edge > 0 ), "'random_pad_node' or 'random_pad_edge' must be positive integers" diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 798b3ba..06f2fbe 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -388,32 +388,6 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: return base_df[base_data[0].keys()].to_dict("records") - @property - def processed_file_names_dict(self) -> dict: - """ - Returns a dictionary for the processed and tokenized data files. - - Returns: - dict: A dictionary mapping dataset keys to their respective file names. - For example, {"data": "data.pt"}. - """ - if self.n_token_limit is not None: - return {"data": f"data_maxlen{self.n_token_limit}.pt"} - - data_pt_filename = "data" - if self.zero_pad_node: - data_pt_filename += f"_zpn{self.zero_pad_node}" - if self.zero_pad_edge: - data_pt_filename += f"_zpe{self.zero_pad_edge}" - if self.random_pad_node: - data_pt_filename += f"_rpn{self.random_pad_node}" - if self.random_pad_edge: - data_pt_filename += f"_rpe{self.random_pad_edge}" - if self.random_pad_node or self.random_pad_edge: - data_pt_filename += f"_D{self.distribution}" - - return {"data": data_pt_filename + ".pt"} - class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): def __init__(self, properties=None, transform=None, **kwargs): From ba9de368ee66584979652a3a031a5474aefafa05 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 10 Oct 2025 13:13:07 +0200 Subject: [PATCH 222/224] lint base file f --- chebai_graph/models/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 6b33799..a99b192 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -74,7 +74,11 @@ class GraphNetWrapper(GraphBaseNet, ABC): """ def __init__( - self, 
config: dict, n_linear_layers: int, n_molecule_properties: Optional[int] = 0, **kwargs + self, + config: dict, + n_linear_layers: int, + n_molecule_properties: Optional[int] = 0, + **kwargs, ) -> None: """ Initialize the GNN and linear layers. @@ -91,7 +95,9 @@ def __init__( self.activation = torch.nn.ELU self.lin_input_dim = self._get_lin_seq_input_dim( gnn_out_dim=gnn_out_dim, - n_molecule_properties=n_molecule_properties if n_molecule_properties is not None else 0, + n_molecule_properties=( + n_molecule_properties if n_molecule_properties is not None else 0 + ), ) lin_hidden_dim = kwargs.get("lin_hidden_dim", gnn_out_dim) From 1ba24e0a7a0ed928282bfb96b718a006252d741c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 10 Oct 2025 13:50:11 +0200 Subject: [PATCH 223/224] include zero padding random init reader static method --- chebai_graph/models/dynamic_gni.py | 36 ++++---- chebai_graph/preprocessing/datasets/chebi.py | 88 ++++++------------- .../preprocessing/reader/static_gni.py | 4 +- 3 files changed, 50 insertions(+), 78 deletions(-) diff --git a/chebai_graph/models/dynamic_gni.py b/chebai_graph/models/dynamic_gni.py index 123ac03..20a6c6a 100644 --- a/chebai_graph/models/dynamic_gni.py +++ b/chebai_graph/models/dynamic_gni.py @@ -41,35 +41,35 @@ def __init__(self, config: dict[str, Any], **kwargs: Any): if not self.complete_randomness: assert ( - "random_pad_node" in config or "random_pad_edge" in config - ), "Missing 'random_pad_node' or 'random_pad_edge' in config when complete_randomness is False" - self.random_pad_node = ( - int(config["random_pad_node"]) - if config.get("random_pad_node") is not None + "pad_node_features" in config or "pad_edge_features" in config + ), "Missing 'pad_node_features' or 'pad_edge_features' in config when complete_randomness is False" + self.pad_node_features = ( + int(config["pad_node_features"]) + if config.get("pad_node_features") is not None else None ) - if self.random_pad_node is not None: + if 
self.pad_node_features is not None: print( - f"[Info] Node features will be padded with {self.random_pad_node} " + f"[Info] Node features will be padded with {self.pad_node_features} " f"new set of random features from distribution {self.distribution} " f"in each forward pass." ) - self.random_pad_edge = ( - int(config["random_pad_edge"]) - if config.get("random_pad_edge") is not None + self.pad_edge_features = ( + int(config["pad_edge_features"]) + if config.get("pad_edge_features") is not None else None ) - if self.random_pad_edge is not None: + if self.pad_edge_features is not None: print( - f"[Info] Edge features will be padded with {self.random_pad_edge} " + f"[Info] Edge features will be padded with {self.pad_edge_features} " f"new set of random features from distribution {self.distribution} " f"in each forward pass." ) assert ( - self.random_pad_node > 0 or self.random_pad_edge > 0 - ), "'random_pad_node' or 'random_pad_edge' must be positive integers" + self.pad_node_features > 0 or self.pad_edge_features > 0 + ), "'pad_node_features' or 'pad_edge_features' must be positive integers" self.resgated: BasicGNN = ResGatedModel( in_channels=self.in_channels, @@ -111,10 +111,10 @@ def forward(self, batch: dict[str, Any]) -> Tensor: new_edge_attr, self.distribution ) else: - if self.random_pad_node is not None: + if self.pad_node_features is not None: pad_node = torch.empty( graph_data.x.shape[0], - self.random_pad_node, + self.pad_node_features, device=self.device, ) RandomFeatureInitializationReader.random_gni( @@ -122,10 +122,10 @@ def forward(self, batch: dict[str, Any]) -> Tensor: ) new_x = torch.cat((graph_data.x, pad_node), dim=1) - if self.random_pad_edge is not None: + if self.pad_edge_features is not None: pad_edge = torch.empty( graph_data.edge_attr.shape[0], - self.random_pad_edge, + self.pad_edge_features, device=self.device, ) RandomFeatureInitializationReader.random_gni( diff --git a/chebai_graph/preprocessing/datasets/chebi.py 
b/chebai_graph/preprocessing/datasets/chebi.py index 94f0ece..c94b772 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -189,45 +189,29 @@ def __init__( self, properties=None, transform=None, - zero_pad_node: int = None, - zero_pad_edge: int = None, - random_pad_node: int = None, - random_pad_edge: int = None, + pad_node_features: int = None, + pad_edge_features: int = None, distribution: str = "normal", **kwargs, ): super().__init__(properties, transform, **kwargs) - self.zero_pad_node = int(zero_pad_node) if zero_pad_node else None - if self.zero_pad_node: - print( - f"[Info] Node-level features will be zero-padded with " - f"{self.zero_pad_node} additional dimensions." - ) - - self.zero_pad_edge = int(zero_pad_edge) if zero_pad_edge else None - if self.zero_pad_edge: - print( - f"[Info] Edge-level features will be zero-padded with " - f"{self.zero_pad_edge} additional dimensions." - ) - - self.random_pad_edge = int(random_pad_edge) if random_pad_edge else None - self.random_pad_node = int(random_pad_node) if random_pad_node else None - if self.random_pad_node or self.random_pad_edge: + self.pad_edge_features = int(pad_edge_features) if pad_edge_features else None + self.pad_node_features = int(pad_node_features) if pad_node_features else None + if self.pad_node_features or self.pad_edge_features: assert ( distribution is not None and distribution in RandomFeatureInitializationReader.DISTRIBUTIONS - ), "When using random padding, a valid distribution must be specified." + ), "When using padding for features, a valid distribution must be specified." self.distribution = distribution - if self.random_pad_node: + if self.pad_node_features: print( - f"[Info] Node-level features will be padded with " - f"{self.random_pad_node} additional dimensions initialized from {self.distribution} distribution." 
+ f"[Info] Node-level features will be padded with random" + f"{self.pad_node_features} values from {self.distribution} distribution." ) - if self.random_pad_edge: + if self.pad_edge_features: print( - f"[Info] Edge-level features will be padded with " - f"{self.random_pad_edge} additional dimensions initialized from {self.distribution} distribution." + f"[Info] Edge-level features will be padded with random" + f"{self.pad_edge_features} values from {self.distribution} distribution." ) if self.properties: @@ -276,24 +260,19 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData: else: raise TypeError(f"Unsupported property type: {type(property).__name__}") - if self.zero_pad_node: - x = torch.cat([x, torch.zeros((x.shape[0], self.zero_pad_node))], dim=1) - - if self.zero_pad_edge: - edge_attr = torch.cat( - [edge_attr, torch.zeros((edge_attr.shape[0], self.zero_pad_edge))], - dim=1, + if self.pad_node_features: + padding_values = torch.empty((x.shape[0], self.pad_node_features)) + RandomFeatureInitializationReader.random_gni( + padding_values, self.distribution ) + x = torch.cat([x, padding_values], dim=1) - if self.random_pad_node: - random_pad = torch.empty((x.shape[0], self.random_pad_node)) - RandomFeatureInitializationReader.random_gni(random_pad, self.distribution) - x = torch.cat([x, random_pad], dim=1) - - if self.random_pad_edge: - random_pad = torch.empty((edge_attr.shape[0], self.random_pad_edge)) - RandomFeatureInitializationReader.random_gni(random_pad, self.distribution) - edge_attr = torch.cat([edge_attr, random_pad], dim=1) + if self.pad_edge_features: + padding_values = torch.empty((edge_attr.shape[0], self.pad_edge_features)) + RandomFeatureInitializationReader.random_gni( + padding_values, self.distribution + ) + edge_attr = torch.cat([edge_attr, padding_values], dim=1) return GeomData( x=x, @@ -350,13 +329,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: ) in_channels_str = "" - if self.zero_pad_node: - 
n_node_properties += self.zero_pad_node - in_channels_str += f" (with {self.zero_pad_node} padded zeros)" - - if self.random_pad_node: - n_node_properties += self.random_pad_node - in_channels_str += f" (with {self.random_pad_node} random padded values from {self.distribution} distribution)" + if self.pad_node_features: + n_node_properties += self.pad_node_features + in_channels_str += f" (with {self.pad_node_features} padded random values from {self.distribution} distribution)" in_channels_str = f"in_channels: {n_node_properties}" + in_channels_str @@ -367,14 +342,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]: if isinstance(p, BondProperty) ) edge_dim_str = "" - - if self.zero_pad_edge: - n_edge_properties += self.zero_pad_edge - edge_dim_str += f" (with {self.zero_pad_edge} padded zeros)" - - if self.random_pad_edge: - n_edge_properties += self.random_pad_edge - edge_dim_str += f" (with {self.random_pad_edge} random padded values from {self.distribution} distribution)" + if self.pad_edge_features: + n_edge_properties += self.pad_edge_features + edge_dim_str += f" (with {self.pad_edge_features} padded random values from {self.distribution} distribution)" edge_dim_str = f"edge_dim: {n_edge_properties}" + edge_dim_str diff --git a/chebai_graph/preprocessing/reader/static_gni.py b/chebai_graph/preprocessing/reader/static_gni.py index 398c08f..106c528 100644 --- a/chebai_graph/preprocessing/reader/static_gni.py +++ b/chebai_graph/preprocessing/reader/static_gni.py @@ -13,7 +13,7 @@ class RandomFeatureInitializationReader(GraphPropertyReader): - DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform"] + DISTRIBUTIONS = ["normal", "uniform", "xavier_normal", "xavier_uniform", "zeros"] def __init__( self, @@ -74,5 +74,7 @@ def random_gni(tensor: torch.Tensor, distribution: str) -> None: torch.nn.init.xavier_normal_(tensor) elif distribution == "xavier_uniform": torch.nn.init.xavier_uniform_(tensor) + elif distribution == 
"zeros": + torch.nn.init.zeros_(tensor) else: raise ValueError("Unknown distribution type") From a7915ec95030c1e720f91fabd0c48d35f40e65a3 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 13 Oct 2025 13:52:31 +0200 Subject: [PATCH 224/224] add optional batch norm --- chebai_graph/models/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chebai_graph/models/base.py b/chebai_graph/models/base.py index 6b33799..79ca12a 100644 --- a/chebai_graph/models/base.py +++ b/chebai_graph/models/base.py @@ -74,7 +74,7 @@ class GraphNetWrapper(GraphBaseNet, ABC): """ def __init__( - self, config: dict, n_linear_layers: int, n_molecule_properties: Optional[int] = 0, **kwargs + self, config: dict, n_linear_layers: int, n_molecule_properties: Optional[int] = 0, use_batch_norm: bool = False, **kwargs ) -> None: """ Initialize the GNN and linear layers. @@ -93,6 +93,9 @@ def __init__( gnn_out_dim=gnn_out_dim, n_molecule_properties=n_molecule_properties if n_molecule_properties is not None else 0, ) + self.use_batch_norm = use_batch_norm + if self.use_batch_norm: + self.batch_norm = torch.nn.BatchNorm1d(self.lin_input_dim) lin_hidden_dim = kwargs.get("lin_hidden_dim", gnn_out_dim) self.lin_sequential: torch.nn.Sequential = self._get_linear_module_list( @@ -180,6 +183,8 @@ def forward(self, batch: dict) -> torch.Tensor: a = self.gnn(batch) a = scatter_add(a, graph_data.batch, dim=0) a = torch.cat([a, graph_data.molecule_attr], dim=1) + if self.use_batch_norm: + a = self.batch_norm(a) return self.lin_sequential(a)