From a850adc9143926996217109671c1feb358513465 Mon Sep 17 00:00:00 2001 From: Kevin Boyd Date: Fri, 29 May 2026 14:09:31 -0400 Subject: [PATCH 1/2] etkdg: parallelize MMFF energy validation across mols `_energy_diff_summary` gains a `num_threads` parameter; when > 1, MMFF energy evaluation runs in worker processes via `process_map`. Wired up from `main()` with `args.rdkit_threads`, so the same flag that controls the timed RDKit run also speeds up the (untimed) validation step. --- benchmarks/etkdg_bench.py | 45 +++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/benchmarks/etkdg_bench.py b/benchmarks/etkdg_bench.py index 0c8402e8..1c5e2045 100644 --- a/benchmarks/etkdg_bench.py +++ b/benchmarks/etkdg_bench.py @@ -48,6 +48,7 @@ from nvmolkit.types import HardwareOptions from rdkit import Chem from rdkit.Chem import AllChem, rdDistGeom +from tqdm.contrib.concurrent import process_map OPTUNA_AVAILABLE = nv_autotune.is_available() @@ -74,18 +75,52 @@ def _mmff_energies(mol: Chem.Mol) -> list[float | None]: return energies +def _mmff_energies_from_binary(mol_bytes: bytes) -> list[float | None]: + """Multiprocessing-friendly wrapper: rebuild Mol from bytes then evaluate.""" + return _mmff_energies(Chem.Mol(mol_bytes)) + + def _energy_diff_summary( rdkit_mols: list[Chem.Mol], nvmolkit_mols: list[Chem.Mol], + num_threads: int = 1, ) -> tuple[float, float, int]: """Mean / median energy difference (RDKit - nvmolkit) and the number of paired conformers. Conformers where either side failed to evaluate (``None``) are skipped. + Energy evaluations across mols run in parallel when ``num_threads > 1``. """ + paired_count = min(len(rdkit_mols), len(nvmolkit_mols)) + if paired_count == 0: + return float("nan"), float("nan"), 0 + + rd_paired = rdkit_mols[:paired_count] + nv_paired = nvmolkit_mols[:paired_count] + + if num_threads > 1: + rd_binaries = [m.ToBinary() for m in rd_paired] + nv_binaries = [m.ToBinary() for m in nv_paired] + chunksize = max(1, paired_count // (num_threads * 8) or 1) + rd_energies_list = process_map( + _mmff_energies_from_binary, + rd_binaries, + max_workers=num_threads, + chunksize=chunksize, + desc="Energy validation (RDKit)", + ) + nv_energies_list = process_map( + _mmff_energies_from_binary, + nv_binaries, + max_workers=num_threads, + chunksize=chunksize, + desc="Energy validation (nvmolkit)", + ) + else: + rd_energies_list = [_mmff_energies(m) for m in rd_paired] + nv_energies_list = [_mmff_energies(m) for m in nv_paired] + deltas: list[float] = [] - for rd_mol, nv_mol in zip(rdkit_mols, nvmolkit_mols): - rd_energies = _mmff_energies(rd_mol) - nv_energies = _mmff_energies(nv_mol) + for rd_energies, nv_energies in zip(rd_energies_list, nv_energies_list): paired = min(len(rd_energies), len(nv_energies)) for i in range(paired): rd_energy = rd_energies[i] @@ -514,7 +549,9 @@ def main() -> None: diff_computed = False if args.validate and "nvmolkit" in results and "rdkit" in results: print("\nValidation (MMFF94 energies)...") - energy_mean, energy_median, energy_pairs = _energy_diff_summary(results["rdkit"][1], results["nvmolkit"][1]) + energy_mean, energy_median, energy_pairs = _energy_diff_summary( + results["rdkit"][1], results["nvmolkit"][1], num_threads=max(1, args.rdkit_threads) + ) diff_computed = energy_pairs > 0 if diff_computed: print( From c4d2d903c28f9c1c52c288030310b9b3cda653e8 Mon Sep 17 00:00:00 2001 From: Kevin Boyd Date: Fri, 29 May 2026 17:35:53 -0400 Subject: [PATCH 2/2] etkdg: trim energy-worker docstring --- benchmarks/etkdg_bench.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/etkdg_bench.py b/benchmarks/etkdg_bench.py index 1c5e2045..f8613de0 100644 --- a/benchmarks/etkdg_bench.py +++ b/benchmarks/etkdg_bench.py @@ -76,7 +76,7 @@ def _mmff_energies(mol: Chem.Mol) -> list[float | None]: def _mmff_energies_from_binary(mol_bytes: bytes) -> list[float | None]: - """Multiprocessing-friendly wrapper: rebuild Mol from bytes then evaluate.""" + """Evaluate MMFF energies from a serialized Mol (picklable for worker processes).""" return _mmff_energies(Chem.Mol(mol_bytes))