diff --git a/benchmarks/etkdg_bench.py b/benchmarks/etkdg_bench.py index 0c8402e..f8613de 100644 --- a/benchmarks/etkdg_bench.py +++ b/benchmarks/etkdg_bench.py @@ -48,6 +48,7 @@ from nvmolkit.types import HardwareOptions from rdkit import Chem from rdkit.Chem import AllChem, rdDistGeom +from tqdm.contrib.concurrent import process_map OPTUNA_AVAILABLE = nv_autotune.is_available() @@ -74,18 +75,52 @@ def _mmff_energies(mol: Chem.Mol) -> list[float | None]: return energies +def _mmff_energies_from_binary(mol_bytes: bytes) -> list[float | None]: + """Evaluate MMFF energies from a serialized Mol (picklable for worker processes).""" + return _mmff_energies(Chem.Mol(mol_bytes)) + + def _energy_diff_summary( rdkit_mols: list[Chem.Mol], nvmolkit_mols: list[Chem.Mol], + num_threads: int = 1, ) -> tuple[float, float, int]: """Mean / median energy difference (RDKit - nvmolkit) and the number of paired conformers. Conformers where either side failed to evaluate (``None``) are skipped. + Energy evaluations across mols run in parallel when ``num_threads > 1``. """ + paired_count = min(len(rdkit_mols), len(nvmolkit_mols)) + if paired_count == 0: + return float("nan"), float("nan"), 0 + + rd_paired = rdkit_mols[:paired_count] + nv_paired = nvmolkit_mols[:paired_count] + + if num_threads > 1: + rd_binaries = [m.ToBinary() for m in rd_paired] + nv_binaries = [m.ToBinary() for m in nv_paired] + chunksize = max(1, paired_count // (num_threads * 8) or 1) + rd_energies_list = process_map( + _mmff_energies_from_binary, + rd_binaries, + max_workers=num_threads, + chunksize=chunksize, + desc="Energy validation (RDKit)", + ) + nv_energies_list = process_map( + _mmff_energies_from_binary, + nv_binaries, + max_workers=num_threads, + chunksize=chunksize, + desc="Energy validation (nvmolkit)", + ) + else: + rd_energies_list = [_mmff_energies(m) for m in rd_paired] + nv_energies_list = [_mmff_energies(m) for m in nv_paired] + deltas: list[float] = [] - for rd_mol, nv_mol in zip(rdkit_mols, nvmolkit_mols): - rd_energies = _mmff_energies(rd_mol) - nv_energies = _mmff_energies(nv_mol) + for rd_energies, nv_energies in zip(rd_energies_list, nv_energies_list): paired = min(len(rd_energies), len(nv_energies)) for i in range(paired): rd_energy = rd_energies[i] @@ -514,7 +549,9 @@ def main() -> None: diff_computed = False if args.validate and "nvmolkit" in results and "rdkit" in results: print("\nValidation (MMFF94 energies)...") - energy_mean, energy_median, energy_pairs = _energy_diff_summary(results["rdkit"][1], results["nvmolkit"][1]) + energy_mean, energy_median, energy_pairs = _energy_diff_summary( + results["rdkit"][1], results["nvmolkit"][1], num_threads=max(1, args.rdkit_threads) + ) diff_computed = energy_pairs > 0 if diff_computed: print(