Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 82 additions & 57 deletions benchmarks/butina_clustering_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,14 @@ def bench_nvmol_with_tanimoto(fps, threshold, neighborlist_max_size):
parser.add_argument("--no-rdkit", action="store_true", help="Disable RDKit benchmarks")
parser.add_argument("--no-fused", action="store_true", help="Disable fused Butina benchmarks")
parser.add_argument("--no-nvmolkit", action="store_true", help="Disable nvMolKit Butina benchmarks")
parser.add_argument("--no-rdkit-lowmem", action="store_true", help="Disable RDKit low-memory benchmarks")
parser.add_argument(
"--include-tanimoto-matrix",
"--rdkit-lowmem",
action="store_true",
help="Include tanimoto matrix calculation in butina timing (for fair comparison with fused_butina)",
help=(
"Enable the RDKit low-memory backend. Off by default because its "
"distance-matrix builder is a pure-Python O(n^2) loop that does "
"not finish in reasonable wall time at sizes >= 40k."
),
)
parser.add_argument(
"--config",
Expand All @@ -170,7 +173,6 @@ def bench_nvmol_with_tanimoto(fps, threshold, neighborlist_max_size):
)
args = parser.parse_args()

include_tanimoto = args.include_tanimoto_matrix
n_runs = args.runs
disabled = set()
if args.no_rdkit:
Expand All @@ -179,7 +181,7 @@ def bench_nvmol_with_tanimoto(fps, threshold, neighborlist_max_size):
disabled.add("fused")
if args.no_nvmolkit:
disabled.add("nvmolkit")
if args.no_rdkit_lowmem:
if not args.rdkit_lowmem:
disabled.add("rdkit_lowmem")

if args.config:
Expand Down Expand Up @@ -211,35 +213,22 @@ def bench_nvmol_with_tanimoto(fps, threshold, neighborlist_max_size):

mols = load_smiles(args.input_smiles_file, max_count=max_size + 100, sanitize=True, seed=args.seed)

if include_tanimoto and len(mols) < max_size:
print(
f"Error: --include-tanimoto-matrix requires at least {max_size} molecules, "
f"but only {len(mols)} were loaded from input",
file=sys.stderr,
)
sys.exit(1)

fps = get_fingerprints(mols)

# All three rdkit paths (cluster_only, with_tanimoto, lowmem) need real
# RDKit fingerprints, so build them once if any rdkit row is planned.
max_rdkit_fps_size = max(
(e["size"] for e in run_plan if "rdkit_lowmem" in e["run"] or ("rdkit" in e["run"] and include_tanimoto)),
(e["size"] for e in run_plan if "rdkit" in e["run"] or "rdkit_lowmem" in e["run"]),
default=0,
)
if max_rdkit_fps_size > 0:
if len(mols) < max_rdkit_fps_size:
print(
f"Error: rdkit benchmarks with fingerprints require at least {max_rdkit_fps_size} "
f"molecules, but only {len(mols)} were loaded from input",
file=sys.stderr,
)
sys.exit(1)
rdkit_fpgen = AllChem.GetMorganGenerator(radius=2, fpSize=1024)
rdkit_fps = [rdkit_fpgen.GetFingerprint(mol) for mol in mols]
else:
rdkit_fps = None

output_path = args.output
cutoffs = [args.cutoff] if args.cutoff is not None else [1e-10, 0.1, 0.2, 0.35, 1.1]
cutoffs = [args.cutoff] if args.cutoff is not None else [1e-10, 0.1, 0.2, 0.35, 1.0]
default_nl_sizes = [8, 16, 32, 64, 128]
results = []

Expand All @@ -254,34 +243,57 @@ def save_results():
runs = entry["run"]
max_nl_sizes = entry.get("neighborlist_sizes", default_nl_sizes)

need_dist = "nvmolkit" in runs or ("rdkit" in runs and not include_tanimoto)
if include_tanimoto:
fps_mat = fps[:size].contiguous()
if need_dist:
dist_mat = 1.0 - crossTanimotoSimilarity(fps_mat).torch()
# with_tanimoto and lowmem need real fingerprints for every mol up to
# `size`; when the input is smaller, those rows are skipped here.
have_real_fps_for_size = rdkit_fps is not None and len(rdkit_fps) >= size

need_real_fps_mat = "fused" in runs or "nvmolkit" in runs
if need_real_fps_mat and len(mols) >= size:
fps_mat_real = fps[:size].contiguous()
else:
fps_mat_real = None
if need_real_fps_mat and fps_mat_real is None:
fps_mat_synth = resize_and_fill_fingerprints(fps, size)
else:
fps_mat = resize_and_fill_fingerprints(fps, size)
if need_dist:
real_size = min(size, len(mols))
fps_mat_synth = None
fps_mat = fps_mat_real if fps_mat_real is not None else fps_mat_synth

need_dist = "nvmolkit" in runs or "rdkit" in runs
if need_dist:
real_size = min(size, len(mols))
if "nvmolkit" in runs or "fused" in runs:
base_dists = 1.0 - crossTanimotoSimilarity(fps[:real_size]).torch()
else:
rdkit_dist = np.empty((real_size, real_size), dtype=np.float64)
for i in range(real_size):
rdkit_dist[i] = BulkTanimotoSimilarity(rdkit_fps[i], rdkit_fps[:real_size])
np.subtract(1.0, rdkit_dist, out=rdkit_dist)
base_dists = torch.from_numpy(rdkit_dist)
if real_size >= size:
dist_mat = base_dists.contiguous()
else:
dist_mat = resize_and_fill(base_dists, size)
del base_dists

for cutoff in cutoffs:
# Don't run large sizes for edge cases.
if cutoff in (1e-10, 1.1) and size > 20000:
if cutoff in (1e-10, 1.0) and size > 20000:
continue

rdkit_time, rdk_std = float("nan"), float("nan")
rdkit_cluster_only_time, rdkit_cluster_only_std = float("nan"), float("nan")
rdkit_with_tanimoto_time, rdkit_with_tanimoto_std = float("nan"), float("nan")
if "rdkit" in runs:
if include_tanimoto:
rdkit_time, rdk_std = bench_rdkit_with_tanimoto(rdkit_fps[:size], cutoff, runs=n_runs)
else:
dist_mat_numpy = dist_mat.cpu().numpy()
rdkit_time, rdk_std = bench_rdkit(dist_mat_numpy, cutoff, runs=n_runs)
print(f"Running rdkit_cluster_only size {size} cutoff {cutoff}")
dist_mat_numpy = dist_mat.cpu().numpy()
rdkit_cluster_only_time, rdkit_cluster_only_std = bench_rdkit(dist_mat_numpy, cutoff, runs=n_runs)
if have_real_fps_for_size:
print(f"Running rdkit_with_tanimoto size {size} cutoff {cutoff}")
rdkit_with_tanimoto_time, rdkit_with_tanimoto_std = bench_rdkit_with_tanimoto(
rdkit_fps[:size], cutoff, runs=n_runs
)

rdkit_lm_time, rdkit_lm_std = float("nan"), float("nan")
if "rdkit_lowmem" in runs:
if "rdkit_lowmem" in runs and have_real_fps_for_size:
print(f"Running rdkit_lowmem size {size} cutoff {cutoff}")
rdkit_lm_time, rdkit_lm_std = bench_rdkit_lowmem(rdkit_fps[:size], cutoff, runs=n_runs)

Expand All @@ -297,20 +309,25 @@ def save_results():

if "nvmolkit" in runs:
for max_nl in max_nl_sizes:
print(f"Running nvmolkit size {size} cutoff {cutoff} max_nl {max_nl}")
if include_tanimoto:
nvmol_result = time_it(
lambda: bench_nvmol_with_tanimoto(fps_mat, cutoff, max_nl),
gpu_sync=True,
runs=n_runs,
)
else:
nvmol_result = time_it(
lambda: bench_nvmol_inner(dist_mat, cutoff, max_nl),
print(f"Running nvmolkit_cluster_only size {size} cutoff {cutoff} max_nl {max_nl}")
nvmolkit_cluster_only_result = time_it(
lambda: bench_nvmol_inner(dist_mat, cutoff, max_nl),
gpu_sync=True,
runs=n_runs,
)
nvmolkit_cluster_only_time = nvmolkit_cluster_only_result.mean_ms
nvmolkit_cluster_only_std = nvmolkit_cluster_only_result.std_ms

nvmolkit_with_tanimoto_time, nvmolkit_with_tanimoto_std = float("nan"), float("nan")
if fps_mat_real is not None:
print(f"Running nvmolkit_with_tanimoto size {size} cutoff {cutoff} max_nl {max_nl}")
nvmolkit_with_tanimoto_result = time_it(
lambda: bench_nvmol_with_tanimoto(fps_mat_real, cutoff, max_nl),
gpu_sync=True,
runs=n_runs,
)
nvmol_time, nvmol_std = nvmol_result.mean_ms, nvmol_result.std_ms
nvmolkit_with_tanimoto_time = nvmolkit_with_tanimoto_result.mean_ms
nvmolkit_with_tanimoto_std = nvmolkit_with_tanimoto_result.std_ms

nvmol_res = butina_nvmol(dist_mat, cutoff, neighborlist_max_size=max_nl).torch()
torch.cuda.synchronize()
Expand All @@ -325,12 +342,16 @@ def save_results():
"size": size,
"cutoff": cutoff,
"max_neighborlist_size": max_nl,
"rdkit_time_ms": rdkit_time,
"rdkit_std_ms": rdk_std,
"rdkit_cluster_only_time_ms": rdkit_cluster_only_time,
"rdkit_cluster_only_std_ms": rdkit_cluster_only_std,
"rdkit_with_tanimoto_time_ms": rdkit_with_tanimoto_time,
"rdkit_with_tanimoto_std_ms": rdkit_with_tanimoto_std,
"rdkit_lowmem_time_ms": rdkit_lm_time,
"rdkit_lowmem_std_ms": rdkit_lm_std,
"nvmol_time_ms": nvmol_time,
"nvmol_std_ms": nvmol_std,
"nvmolkit_cluster_only_time_ms": nvmolkit_cluster_only_time,
"nvmolkit_cluster_only_std_ms": nvmolkit_cluster_only_std,
"nvmolkit_with_tanimoto_time_ms": nvmolkit_with_tanimoto_time,
"nvmolkit_with_tanimoto_std_ms": nvmolkit_with_tanimoto_std,
"fused_butina_time_ms": fused_time,
"fused_butina_std_ms": fused_std,
}
Expand All @@ -341,12 +362,16 @@ def save_results():
"size": size,
"cutoff": cutoff,
"max_neighborlist_size": float("nan"),
"rdkit_time_ms": rdkit_time,
"rdkit_std_ms": rdk_std,
"rdkit_cluster_only_time_ms": rdkit_cluster_only_time,
"rdkit_cluster_only_std_ms": rdkit_cluster_only_std,
"rdkit_with_tanimoto_time_ms": rdkit_with_tanimoto_time,
"rdkit_with_tanimoto_std_ms": rdkit_with_tanimoto_std,
"rdkit_lowmem_time_ms": rdkit_lm_time,
"rdkit_lowmem_std_ms": rdkit_lm_std,
"nvmol_time_ms": float("nan"),
"nvmol_std_ms": float("nan"),
"nvmolkit_cluster_only_time_ms": float("nan"),
"nvmolkit_cluster_only_std_ms": float("nan"),
"nvmolkit_with_tanimoto_time_ms": float("nan"),
"nvmolkit_with_tanimoto_std_ms": float("nan"),
"fused_butina_time_ms": fused_time,
"fused_butina_std_ms": fused_std,
}
Expand Down
Loading