diff --git a/README.md b/README.md index c6d2fb0..5487922 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ SMILES → MQN Fingerprints → PQ Encoding → PQk-means Clustering → Nested **Key Features**: - **Scalability**: Stream billions of molecules without loading everything into memory - **Efficiency**: Compress 42-dimensional MQN vectors to 6-byte PQ codes (28x compression) +- **GPU acceleration**: Optional CUDA support for PQ encoding and cluster assignment (~25x speedup) - **Visualization**: Navigate from global overview to individual molecules in two clicks - **Accessibility**: Runs on commodity hardware (tested: AMD Ryzen 7, 64GB RAM) @@ -44,6 +45,52 @@ pip install -e . **Apple Silicon (M1/M2/M3)**: The `pqkmeans` library is not currently supported on Apple Silicon Macs. My plan is to rewrite pqkmeans with Silicon and GPU support but that's for a future release... For now, clustering functionality requires an x86_64 system. +## GPU Acceleration + +Both `PQEncoder.transform()` and `PQKMeans.predict()` support optional GPU acceleration via the `device` parameter. When a CUDA GPU is available, `device='auto'` (the default) uses the GPU transparently; otherwise it falls back to CPU. + +**Requirements**: `torch` and `triton` (`triton` ships as a dependency of the Linux CUDA builds of `torch`; on other platforms install it separately with `pip install triton`). 
+ +```python +encoder = PQEncoder.load('encoder.joblib') +clusterer = PQKMeans.load('clusterer.joblib') + +# GPU is used automatically when available +pq_codes = encoder.transform(fingerprints) # device='auto' by default +labels = clusterer.predict(pq_codes) # device='auto' by default + +# Or force a specific device +labels_cpu = clusterer.predict(pq_codes, device='cpu') +labels_gpu = clusterer.predict(pq_codes, device='gpu') +``` + +**Benchmarks** (20M molecules, K=100,000 clusters, RTX 4070 Ti 16GB; note: the repository bundles `data/10M_smiles.txt.gz`, so running the benchmark on it measures 10M molecules — per-molecule timings scale linearly): + +| Step | GPU | CPU | Speedup | |---:|---:|---:|---:| | PQ Transform | 7.3s | 45.3s | 6.2x | | Cluster Assignment | 29.9s | ~879s | 29.4x | + +Extrapolated to **9.6B molecules** (Enamine REAL): + +| Step | GPU | CPU | |---:|---:|---:| | PQ Transform | 59 min | 6.0 h | | Cluster Assignment | 4.0 h | 117 h | | **Combined** | **5.0 h** | **123 h** | + +The GPU implementation uses a custom Triton kernel for cluster assignment that tiles over centers with an online argmin, never materializing the N × K distance matrix. VRAM usage is ~10 bytes/point, so even an 8 GB GPU can process hundreds of millions of points per batch. 
+ +To reproduce the benchmarks: + +```bash +# Decompress the test SMILES (if using the gzipped version) +gunzip -k data/10M_smiles.txt.gz + +# Run benchmark (pre-computes and caches fingerprints on first run) +python scripts/benchmark_gpu_predict.py +``` + ## Quick Start ```python diff --git a/chelombus/clustering/PyQKmeans.py b/chelombus/clustering/PyQKmeans.py index 6eb37b0..97b8168 100644 --- a/chelombus/clustering/PyQKmeans.py +++ b/chelombus/clustering/PyQKmeans.py @@ -15,6 +15,15 @@ from numba import njit, prange from chelombus import PQEncoder +_GPU_AVAILABLE = False +try: + import torch + if torch.cuda.is_available(): + from chelombus.clustering._gpu_predict import predict_gpu + _GPU_AVAILABLE = True +except ImportError: + pass + def _build_distance_tables(codewords: np.ndarray) -> np.ndarray: """Precompute symmetric squared-distance tables per subvector. @@ -124,14 +133,12 @@ def fit(self, X_train: np.ndarray) -> 'PQKMeans': self._centers_u8 = None return self - def predict(self, X: np.ndarray) -> np.ndarray: + def predict(self, X: np.ndarray, device: str = 'auto') -> np.ndarray: """Predict cluster labels for PQ codes. - Uses Numba JIT-compiled parallel assignment with precomputed - symmetric distance lookup tables - Args: X: PQ codes of shape (n_samples, n_subvectors), dtype uint8 + device: 'cpu' for Numba, 'gpu' for Triton/CUDA, 'auto' to pick GPU if available. 
Returns: Cluster labels of shape (n_samples,) @@ -141,7 +148,16 @@ def predict(self, X: np.ndarray) -> np.ndarray: if self._dtables is None: self._dtables = _build_distance_tables(self.encoder.codewords) self._centers_u8 = self.cluster_centers_.astype(np.uint8) - return _predict_numba(np.asarray(X, dtype=np.uint8), self._centers_u8, self._dtables) + + use_gpu = (device == 'gpu') or (device == 'auto' and _GPU_AVAILABLE) + codes = np.asarray(X, dtype=np.uint8) + + if use_gpu: + if not _GPU_AVAILABLE: + raise RuntimeError("GPU requested but CUDA/Triton not available") + return predict_gpu(codes, self._centers_u8, self._dtables) + + return _predict_numba(codes, self._centers_u8, self._dtables) def fit_predict(self, X: np.ndarray) -> np.ndarray: """Fit the model and predict cluster labels in one step. diff --git a/chelombus/clustering/_gpu_predict.py b/chelombus/clustering/_gpu_predict.py new file mode 100644 index 0000000..a911473 --- /dev/null +++ b/chelombus/clustering/_gpu_predict.py @@ -0,0 +1,202 @@ +"""GPU-accelerated PQ assignment using Triton kernels. + +Provides a drop-in replacement for _predict_numba that runs on CUDA GPUs. +The kernel computes symmetric PQ distance via lookup tables and maintains +an online argmin, never materializing the N x K distance matrix. + +VRAM budget per call +-------------------- +Resident (cached, allocated once): + centers: K × M bytes (100K × 6 = 600 KB) + dtables: M × 256 × 256 × 4 bytes (6 × 256 × 256 × 4 = 1.5 MB) + +Per-batch (freed after each batch): + codes: batch_n × M bytes (1 byte per subvector) + labels: batch_n × 4 bytes (int32) + → 10 bytes per point + +So for a given free VRAM of F bytes, max batch ≈ F / 10. 
+""" + +import numpy as np +import torch +import triton +import triton.language as tl + +# Fixed VRAM overhead for PyTorch/Triton context (conservative) +_VRAM_OVERHEAD_MB = 256 + + +@triton.jit +def _pq_assign_kernel( + codes_ptr, # (N, M) uint8 — PQ codes for data points + centers_ptr, # (K, M) uint8 — PQ codes for cluster centers + dtables_ptr, # (M, 256, 256) float32 — precomputed distance tables + labels_ptr, # (N,) int32 — output cluster assignments + N, # number of data points + K, # number of cluster centers + M: tl.constexpr, # number of subvectors + BLOCK_N: tl.constexpr, # number of points per program + BLOCK_K: tl.constexpr, # number of centers per tile +): + pid = tl.program_id(0) + point_offs = pid * BLOCK_N + tl.arange(0, BLOCK_N) + point_mask = point_offs < N + + best_dist = tl.full((BLOCK_N,), float('inf'), dtype=tl.float32) + best_label = tl.zeros((BLOCK_N,), dtype=tl.int32) + + # Pre-load point codes per subvector (M=6, unrolled) + pc0 = tl.load(codes_ptr + point_offs * M + 0, mask=point_mask, other=0).to(tl.int32) + pc1 = tl.load(codes_ptr + point_offs * M + 1, mask=point_mask, other=0).to(tl.int32) + pc2 = tl.load(codes_ptr + point_offs * M + 2, mask=point_mask, other=0).to(tl.int32) + pc3 = tl.load(codes_ptr + point_offs * M + 3, mask=point_mask, other=0).to(tl.int32) + pc4 = tl.load(codes_ptr + point_offs * M + 4, mask=point_mask, other=0).to(tl.int32) + pc5 = tl.load(codes_ptr + point_offs * M + 5, mask=point_mask, other=0).to(tl.int32) + + # Tile over centers + for c_start in range(0, K, BLOCK_K): + c_offs = c_start + tl.arange(0, BLOCK_K) + c_mask = c_offs < K + + cc0 = tl.load(centers_ptr + c_offs * M + 0, mask=c_mask, other=0).to(tl.int32) + cc1 = tl.load(centers_ptr + c_offs * M + 1, mask=c_mask, other=0).to(tl.int32) + cc2 = tl.load(centers_ptr + c_offs * M + 2, mask=c_mask, other=0).to(tl.int32) + cc3 = tl.load(centers_ptr + c_offs * M + 3, mask=c_mask, other=0).to(tl.int32) + cc4 = tl.load(centers_ptr + c_offs * M + 4, mask=c_mask, 
other=0).to(tl.int32) + cc5 = tl.load(centers_ptr + c_offs * M + 5, mask=c_mask, other=0).to(tl.int32) + + TABLE = 256 * 256 + + idx0 = pc0[:, None] * 256 + cc0[None, :] + dist = tl.load(dtables_ptr + 0 * TABLE + idx0) + + idx1 = pc1[:, None] * 256 + cc1[None, :] + dist += tl.load(dtables_ptr + 1 * TABLE + idx1) + + idx2 = pc2[:, None] * 256 + cc2[None, :] + dist += tl.load(dtables_ptr + 2 * TABLE + idx2) + + idx3 = pc3[:, None] * 256 + cc3[None, :] + dist += tl.load(dtables_ptr + 3 * TABLE + idx3) + + idx4 = pc4[:, None] * 256 + cc4[None, :] + dist += tl.load(dtables_ptr + 4 * TABLE + idx4) + + idx5 = pc5[:, None] * 256 + cc5[None, :] + dist += tl.load(dtables_ptr + 5 * TABLE + idx5) + + dist = tl.where(c_mask[None, :], dist, float('inf')) + + tile_min_dist = tl.min(dist, axis=1) + tile_min_idx = tl.argmin(dist, axis=1).to(tl.int32) + tile_min_label = c_start + tile_min_idx + + update_mask = tile_min_dist < best_dist + best_dist = tl.where(update_mask, tile_min_dist, best_dist) + best_label = tl.where(update_mask, tile_min_label, best_label) + + tl.store(labels_ptr + point_offs, best_label, mask=point_mask) + + +# --------------------------------------------------------------------------- +# GPU tensor cache (centers + dtables persist across predict calls) +# --------------------------------------------------------------------------- +_gpu_cache: dict = {} + + +def _get_or_upload(key: str, array: np.ndarray, dtype: torch.dtype) -> torch.Tensor: + """Upload numpy array to GPU, caching by (key, data_ptr, shape).""" + cache_key = (key, array.ctypes.data, array.shape) + cached = _gpu_cache.get(cache_key) + if cached is not None: + return cached + tensor = torch.from_numpy(np.ascontiguousarray(array)).to(dtype=dtype, device='cuda') + _gpu_cache[cache_key] = tensor + return tensor + + +def _auto_batch_size(N: int, M: int) -> int: + """Compute the max batch size that fits in available VRAM. 
+ + Per-point VRAM: M bytes (codes) + 4 bytes (labels) = M + 4 + We leave _VRAM_OVERHEAD_MB for PyTorch/Triton context and cached tensors. + """ + free, total = torch.cuda.mem_get_info() + usable = free - _VRAM_OVERHEAD_MB * 1024 * 1024 + if usable < 0: + usable = free // 2 + bytes_per_point = M + 4 # uint8 codes + int32 label + max_batch = max(usable // bytes_per_point, 1024) + return min(max_batch, N) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def predict_gpu( + pq_codes: np.ndarray, + centers: np.ndarray, + dtables: np.ndarray, + batch_size: int = 0, +) -> np.ndarray: + """GPU-accelerated PQ assignment. + + Args: + pq_codes: (N, M) uint8 PQ codes. + centers: (K, M) uint8 cluster center codes. + dtables: (M, 256, 256) float32 distance lookup tables. + batch_size: Max points per GPU batch. + 0 (default) = auto-detect from free VRAM. + + Returns: + (N,) int64 cluster labels (same dtype as CPU path). + """ + N, M = pq_codes.shape + K = centers.shape[0] + + if M != 6: + raise ValueError(f"Triton kernel is compiled for M=6, got M={M}") + + # Kernel assumes dtables are (M, 256, 256). Pad if k_codebook < 256. 
+ if dtables.shape[1] != 256 or dtables.shape[2] != 256: + padded = np.zeros((M, 256, 256), dtype=np.float32) + k_cb = dtables.shape[1] + padded[:, :k_cb, :k_cb] = dtables + dtables = padded + + # Cache centers and dtables on GPU (persist across calls) + centers_gpu = _get_or_upload('centers', centers, torch.uint8) + dtables_gpu = _get_or_upload('dtables', dtables, torch.float32) + + if batch_size <= 0: + batch_size = _auto_batch_size(N, M) + + labels_out = np.empty(N, dtype=np.int64) + + for start in range(0, N, batch_size): + end = min(start + batch_size, N) + chunk = pq_codes[start:end] + n_chunk = end - start + + codes_gpu = torch.from_numpy(np.ascontiguousarray(chunk, dtype=np.uint8)).cuda() + labels_gpu = torch.empty(n_chunk, dtype=torch.int32, device='cuda') + + _launch_kernel(codes_gpu, centers_gpu, dtables_gpu, labels_gpu, n_chunk, K, M) + + labels_out[start:end] = labels_gpu.cpu().numpy() + del codes_gpu, labels_gpu + + return labels_out + + +def _launch_kernel(codes, centers, dtables, labels, N, K, M): + BLOCK_N = 32 + BLOCK_K = 128 + grid = ((N + BLOCK_N - 1) // BLOCK_N,) + _pq_assign_kernel[grid]( + codes, centers, dtables, labels, + N, K, M, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, + ) diff --git a/chelombus/encoder/encoder.py b/chelombus/encoder/encoder.py index fef0192..7208d14 100644 --- a/chelombus/encoder/encoder.py +++ b/chelombus/encoder/encoder.py @@ -5,6 +5,14 @@ from pathlib import Path import joblib from numpy.typing import NDArray + +_GPU_AVAILABLE = False +try: + import torch + if torch.cuda.is_available(): + _GPU_AVAILABLE = True +except ImportError: + print('something went wrong') class PQEncoder(PQEncoderBase): """ Class to encode high-dimensional vectors into PQ-codes. 
@@ -88,50 +96,86 @@ def fit(self, X_train:NDArray, verbose:int=1, **kwargs)->None: del X_train # remove initial training data from memory - def transform(self, X:NDArray, verbose:int=1, **kwargs) -> NDArray: + def transform(self, X:NDArray, verbose:int=1, device:str='auto', **kwargs) -> NDArray: """ Transforms the input matrix X into its PQ-codes. - For each sample in X, the input vector is split into `m` equal-sized subvectors. + For each sample in X, the input vector is split into `m` equal-sized subvectors. Each subvector is assigned to the nearest cluster centroid - and the index of the closest centroid is stored. + and the index of the closest centroid is stored. The result is a compact representation of X, where each sample is encoded as a sequence of centroid indices. Args: - X (np.ndarray): Input data matrix of shape (n_samples, n_features), + X (np.ndarray): Input data matrix of shape (n_samples, n_features), where n_features must be divisible by the number of subvectors `m`. verbose(int): Level of verbosity. Default is 1 + device: 'cpu' for sklearn, 'gpu' for torch.cdist on CUDA, 'auto' to pick GPU if available. **kwargs: Optional keyword arguments passed to the underlying KMeans `predict()` function. Returns: - np.ndarray: PQ codes of shape (n_samples, m), where each element is the index of the nearest centroid + np.ndarray: PQ codes of shape (n_samples, m), where each element is the index of the nearest centroid for the corresponding subvector. 
""" assert self.encoder_is_trained == True, "PQEncoder must be trained before calling transform" + use_gpu = (device == 'gpu') or (device == 'auto' and _GPU_AVAILABLE) + if use_gpu: + if not _GPU_AVAILABLE: + raise RuntimeError("GPU requested but CUDA not available") + return self._transform_gpu(X) + + return self._transform_cpu(X, verbose, **kwargs) + + def _transform_cpu(self, X: NDArray, verbose: int = 1, **kwargs) -> NDArray: N, D = X.shape - # Store the index of the Nearest centroid for each subvector pq_codes = np.zeros((N, self.m), dtype=self.codebook_dtype) iterable = range(self.m) - # If our original vector is 1024 and our m (splits) is 8 then each subvector will be of dim= 1024/8 = 128 - if verbose > 0: + if verbose > 0: iterable = tqdm(iterable, desc='Generating PQ-codes', total=self.m) subvector_dim = int(D / self.m) for subvector_idx in iterable: - X_train_subvector = X[:, subvector_dim * subvector_idx : subvector_dim * (subvector_idx + 1)] - # For every subvector, run KMeans.predict(). Then look in the codebook for the index of the cluster that is closest - # Appends the centroid index to the pq_code. - pq_codes[:, subvector_idx] = self.pq_trained[subvector_idx].predict(X_train_subvector, **kwargs) - - # Free memory - del X + X_train_subvector = X[:, subvector_dim * subvector_idx : subvector_dim * (subvector_idx + 1)] + pq_codes[:, subvector_idx] = self.pq_trained[subvector_idx].predict(X_train_subvector, **kwargs) + + del X + return pq_codes + + def _transform_gpu(self, X: NDArray) -> NDArray: + """GPU-accelerated transform using torch.cdist + argmin. + + Batches points to avoid OOM: each batch produces an (batch, k) distance + matrix of k × 4 bytes per point. With k=256 that is 1 KB/point, so + a 2 GB budget ≈ 2M points per batch. 
+ """ + N, D = X.shape + subvector_dim = int(D / self.m) + pq_codes = np.zeros((N, self.m), dtype=self.codebook_dtype) - # Return pq_codes (labels of the centroids for every subvector from the X_test data) + cw_gpu = [ + torch.from_numpy(self.codewords[sub].astype(np.float32)).cuda() + for sub in range(self.m) + ] + + # Budget: keep the (batch, k) distance matrix under ~2 GB + max_batch = max((2 * 1024**3) // (self.k * 4), 1024) + X_f32 = X.astype(np.float32) + + for start in range(0, N, max_batch): + end = min(start + max_batch, N) + for sub in range(self.m): + chunk = torch.from_numpy( + np.ascontiguousarray(X_f32[start:end, subvector_dim * sub : subvector_dim * (sub + 1)]) + ).cuda() + dists = torch.cdist(chunk, cw_gpu[sub]) + pq_codes[start:end, sub] = dists.argmin(dim=1).cpu().numpy() + del chunk, dists + + del X, cw_gpu return pq_codes def fit_transform(self, X:NDArray, verbose:int=1, **kwargs) -> NDArray: diff --git a/data/10M_smiles.txt.gz b/data/10M_smiles.txt.gz new file mode 100644 index 0000000..62ffa99 Binary files /dev/null and b/data/10M_smiles.txt.gz differ diff --git a/scripts/benchmark_gpu_predict.py b/scripts/benchmark_gpu_predict.py new file mode 100644 index 0000000..d0118e2 --- /dev/null +++ b/scripts/benchmark_gpu_predict.py @@ -0,0 +1,194 @@ +"""Benchmark: GPU vs CPU for PQ Transform and Predict. + +Pre-computes fingerprints to data/20M_fingerprints.npy so they can be +reloaded without recomputing. Fingerprint time is NOT included in the report. 
+ +Usage: + python scripts/benchmark_gpu_predict.py [--n-points N] [--runs R] + python scripts/benchmark_gpu_predict.py # full 20M + python scripts/benchmark_gpu_predict.py --n-points 1000000 # quick test +""" + +import argparse +import sys +import time +from pathlib import Path + +import numpy as np + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from chelombus import PQEncoder, FingerprintCalculator +from chelombus.clustering.PyQKmeans import PQKMeans +from chelombus.utils import format_time + + +FP_CACHE = Path("data/20M_fingerprints.npy") + + +def load_fingerprints(n_points: int) -> np.ndarray: + """Load pre-computed fingerprints, or compute and cache them.""" + if FP_CACHE.exists(): + print(f"Loading cached fingerprints from {FP_CACHE}...") + fps = np.load(FP_CACHE) + if n_points < len(fps): + fps = fps[:n_points] + print(f" shape={fps.shape}, dtype={fps.dtype}") + return fps + + # Compute and save — try multiple SMILES sources + smiles_path = None + for candidate in [ + Path("data/20M_smiles.txt"), + Path("data/10M_smiles.txt"), + ]: + if candidate.exists(): + smiles_path = candidate + break + # Try decompressing a gzipped version + if smiles_path is None: + for gz_candidate in [ + Path("data/20M_smiles.txt.gz"), + Path("data/10M_smiles.txt.gz"), + ]: + if gz_candidate.exists(): + smiles_path = gz_candidate.with_suffix("").with_suffix(".txt") + print(f"Decompressing {gz_candidate}...") + import gzip, shutil + with gzip.open(gz_candidate, "rt") as gz, open(smiles_path, "w") as out: + shutil.copyfileobj(gz, out) + break + if smiles_path is None or not smiles_path.exists(): + raise FileNotFoundError( + "No SMILES file found in data/. Expected 10M_smiles.txt(.gz) or 20M_smiles.txt(.gz)." + ) + print(f"No cached fingerprints found. 
Computing from {smiles_path}...") + smiles = [] + with open(smiles_path) as f: + for line in f: + smiles.append(line.strip()) + fp_calc = FingerprintCalculator() + fps = fp_calc.FingerprintFromSmiles(smiles, "mqn") + np.save(FP_CACHE, fps) + print(f" Saved to {FP_CACHE} ({FP_CACHE.stat().st_size / 1024**2:.1f} MB)") + if n_points < len(fps): + fps = fps[:n_points] + return fps + + +def timed_runs(fn, n_runs, label): + """Run fn() n_runs times, print each run, return median time.""" + times = [] + result = None + for r in range(n_runs): + t0 = time.perf_counter() + out = fn() + elapsed = time.perf_counter() - t0 + times.append(elapsed) + if result is None: + result = out + print(f" {label} run {r+1}: {elapsed:.4f}s") + return np.median(times), result + + +def benchmark_transform(encoder, fps, n_warmup, n_runs): + N = fps.shape[0] + sep = "=" * 60 + print(f"\n{sep}") + print(f"TRANSFORM: {N:,} fingerprints → PQ codes (m={encoder.m}, k={encoder.k})") + print(sep) + + results = {} + for device in ["gpu", "cpu"]: + tag = device.upper() + + # Warmup + for _ in range(n_warmup): + encoder.transform(fps[:1000], verbose=0, device=device) + + med, codes = timed_runs( + lambda d=device: encoder.transform(fps, verbose=0, device=d), + n_runs, tag, + ) + throughput = N / med + results[device] = {"median": med, "throughput": throughput, "codes": codes} + print(f" → {tag} median: {med:.3f}s ({throughput:,.0f} mol/s)") + + gpu, cpu = results["gpu"], results["cpu"] + mismatches = int(np.sum(gpu["codes"] != cpu["codes"])) + + print(f"\n Speedup: {cpu['median']/gpu['median']:.1f}x") + print(f" Correctness: {mismatches}/{gpu['codes'].size} mismatches " + f"({100*mismatches/gpu['codes'].size:.3f}% — float32 tie-breaking)") + print(f" Extrapolation:") + for label, total in [("1B", 1e9), ("9.6B", 9.6e9)]: + print(f" {label}: GPU {format_time(total/gpu['throughput'])} | CPU {format_time(total/cpu['throughput'])}") + + return gpu["codes"] + + +def benchmark_predict(clusterer, pq_codes, 
n_warmup, n_runs): + N = pq_codes.shape[0] + K = clusterer.k + sep = "=" * 60 + print(f"\n{sep}") + print(f"PREDICT: {N:,} PQ codes → cluster labels (K={K:,})") + print(sep) + + results = {} + for device in ["gpu", "cpu"]: + tag = device.upper() + # CPU with K=100K is very slow; limit subset + n_bench = N + bench_codes = pq_codes[:n_bench] + + # Warmup + for _ in range(n_warmup): + clusterer.predict(bench_codes[:1000], device=device) + + med, labels = timed_runs( + lambda d=device, bc=bench_codes: clusterer.predict(bc, device=d), + n_runs, tag, + ) + throughput = n_bench / med + results[device] = { + "median": med, "throughput": throughput, "labels": labels, "n": n_bench, + } + print(f" → {tag} median: {med:.3f}s for {n_bench:,} points ({throughput:,.0f} codes/sec)") + + gpu, cpu = results["gpu"], results["cpu"] + overlap = min(gpu["n"], cpu["n"]) + match = int(np.sum(gpu["labels"][:overlap] == cpu["labels"][:overlap])) + + print(f"\n Speedup: {gpu['throughput']/cpu['throughput']:.1f}x") + print(f" Correctness: {match:,}/{overlap:,} match ({100*match/overlap:.2f}%)") + print(f" Extrapolation:") + for label, total in [("1B", 1e9), ("9.6B", 9.6e9)]: + print(f" {label}: GPU {format_time(total/gpu['throughput'])} | CPU {format_time(total/cpu['throughput'])}") + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark GPU vs CPU: Transform + Predict") + parser.add_argument("--n-points", type=int, default=0, + help="Number of fingerprints (0 = use all cached, default: all 20M)") + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--runs", type=int, default=3) + parser.add_argument("--encoder", default="models/paper_encoder.joblib") + parser.add_argument("--clusterer", default="models/paper_clusterer.joblib") + args = parser.parse_args() + + print("Loading models...") + encoder = PQEncoder.load(args.encoder) + clusterer = PQKMeans.load(args.clusterer) + print(f" Encoder: m={encoder.m}, k={encoder.k}") + print(f" Clusterer: 
K={clusterer.k:,}") + + n = args.n_points if args.n_points > 0 else 20_000_001 + fps = load_fingerprints(n) + + pq_codes = benchmark_transform(encoder, fps, args.warmup, args.runs) + benchmark_predict(clusterer, pq_codes, args.warmup, args.runs) + + +if __name__ == "__main__": + main() diff --git a/scripts/cluster_smiles.py b/scripts/cluster_smiles.py new file mode 100644 index 0000000..9c59cc4 --- /dev/null +++ b/scripts/cluster_smiles.py @@ -0,0 +1,116 @@ +"""Cluster SMILES using pre-trained PQEncoder + PQKMeans models. + +Usage: + python scripts/cluster_smiles.py \ + --input /mnt/10tb_hdd/cleaned_enamine_10b/output_file_0.cxsmiles \ + --output /mnt/samsung_2tb/tmp/ \ + --smiles-col 1 \ + --chunksize 1000000 + +Uses GPU for cluster assignment when available (device='auto'). +""" + +import argparse +import os +import sys +import time +from pathlib import Path + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +from chelombus import PQEncoder, DataStreamer, FingerprintCalculator +from chelombus.clustering.PyQKmeans import PQKMeans + + +def format_time(seconds): + if seconds < 60: + return f"{seconds:.0f}s" + elif seconds < 3600: + return f"{seconds/60:.1f}m" + else: + return f"{seconds/3600:.1f}h" + + +def main(): + parser = argparse.ArgumentParser(description="Cluster SMILES with pre-trained models") + parser.add_argument("--input", required=True, help="Input SMILES file") + parser.add_argument("--output", required=True, help="Output directory for parquet files") + parser.add_argument("--encoder", default="models/paper_encoder.joblib") + parser.add_argument("--clusterer", default="models/paper_clusterer.joblib") + parser.add_argument("--chunksize", type=int, default=1_000_000) + parser.add_argument("--smiles-col", type=int, default=0, + help="Column index for SMILES (0-indexed)") + parser.add_argument("--device", default="auto", choices=["auto", "gpu", "cpu"]) + 
parser.add_argument("--resume", action="store_true", + help="Skip chunks that already have output files") + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + + print(f"Loading encoder from {args.encoder}") + encoder = PQEncoder.load(args.encoder) + print(f"Loading clusterer from {args.clusterer}") + clusterer = PQKMeans.load(args.clusterer) + print(f" k={clusterer.k:,} clusters, m={encoder.m} subvectors") + + stream = DataStreamer() + fp_calc = FingerprintCalculator() + + total_molecules = 0 + total_time = 0 + start = time.perf_counter() + + for i, chunk in enumerate(stream.parse_input( + args.input, chunksize=args.chunksize, + smiles_col=args.smiles_col, verbose=0, + )): + # Resume support: skip if output file exists + out_file = os.path.join(args.output, f"chunk_{i:05d}.parquet") + if args.resume and os.path.exists(out_file): + total_molecules += len(chunk) + continue + + t0 = time.perf_counter() + + # SMILES to MQN fingerprints + fps = fp_calc.FingerprintFromSmiles(chunk, "mqn") + + # MQN to PQ codes (GPU when available) + pq_codes = encoder.transform(fps, verbose=0, device=args.device) + + # PQ codes to cluster labels (GPU when available) + labels = clusterer.predict(pq_codes, device=args.device) + + # Build output only include rows where fingerprint succeeded + if len(fps) == len(chunk): + table = pa.table({"smiles": chunk, "cluster_id": labels}) + else: + table = pa.table({"cluster_id": labels}) + + pq.write_table(table, out_file) + + elapsed = time.perf_counter() - t0 + total_molecules += len(chunk) + total_time += elapsed + rate = total_molecules / (time.perf_counter() - start) + + print( + f"\rChunk {i:>5d} | {total_molecules:>12,} molecules | " + f"{rate:,.0f} mol/s | chunk: {elapsed:.1f}s | " + f"ETA: {format_time((9_600_000_000 - total_molecules) / rate) if rate > 0 else '?'}", + end="", flush=True, + ) + + del chunk, fps, pq_codes, labels, table + + elapsed_total = time.perf_counter() - start + print(f"\n\nDone: 
{total_molecules:,} molecules in {format_time(elapsed_total)}") + print(f"Output: {args.output}") + + +if __name__ == "__main__": + main()