Patch description (free text placed ahead of the first "diff --git" header; it belongs to no hunk).

NOTE(review): the diff text below appears to have had its line breaks and indentation collapsed to single spaces and then been re-wrapped, so hunks and Python block structure are not recoverable from it alone. Regenerate the patch from the original commit before trying to apply it.

1. debug/bench_pairwise_distance.py (new file, 69 lines per the hunk header)
   - make_data(n_haps, n_snps, seed=42): draws 5 random 0/1 founder rows (int8), assigns every haplotype to one founder, flips each entry with probability 0.02 via XOR, builds positions as arange(n_snps) * 100, wraps everything in HaplotypeMatrix(haps, positions, positions[0], positions[-1]) and calls transfer_to_gpu() before returning it.
   - bench(fn, n_reps=5): calls fn once untimed, then times n_reps calls with time.perf_counter(), synchronizing cp.cuda.Device(0) around each call; returns the median multiplied by 1000 (the table header printed by main labels these values "ms").
   - main(): for (20, 200) and (100, 2000) compares pairwise_distance(hm, metric='euclidean') with scipy pdist on the float64 host copy of hm.haplotypes and prints PASS when max|d_gpu - d_ref| / max|d_ref| < 1e-10; then prints timings of the 'euclidean' and 'sqeuclidean' metrics for five (n_haps, n_snps) configurations.
   - NOTE(review): the module docstring says "Gram matrix trick vs batched pairwise", but main() only times two metrics of the same call; presumably the comparison is meant to be made by running the script before and after this change -- confirm.
   - NOTE(review): a FAIL is only printed; nothing in main() raises or changes the exit status.

2. pg_gpu/decomposition.py, function pairwise_distance (only the changed hunks are visible here)
   - has_missing is now bool(cp.any(hap < 0).get()); n and m are read from hap.shape instead of from X.
   - New fast path, taken when has_missing is False and metric is 'euclidean' or 'sqeuclidean': G (n x n, float64) accumulates X_chunk @ X_chunk.T over column slices hap[:, col_start:col_end] cast to float64, with the slice width coming from estimate_variant_chunk_size(n, bytes_per_element=8, n_intermediates=2). Squared distances are then diag(G)[i] + diag(G)[j] - 2 * G[i, j], clamped at 0.0, read at cp.triu_indices(n, k=1), square-rooted for 'euclidean', and returned via .get().
   - NOTE(review): `del X_chunk` and `free_gpu_pool()` follow the accumulation statement; whether they sit inside the chunk loop cannot be determined from the collapsed text.
   - NOTE(review): estimate_variant_chunk_size is imported from ._memutil and is not shown; range(0, m, chunk_size) would raise ValueError if it ever returned 0 -- confirm it is always >= 1.
   - X and valid_mask are now built only after the fast path, for the batched path (cityblock, or any metric with missing data); bytes_per_pair and the missing-data normalization use m in place of n_variants / X.shape[1]. Four explanatory comments of the old code are dropped.
   - The scipy pdist fallback branch now builds its own X, because X is no longer created before the metric dispatch.

diff --git a/debug/bench_pairwise_distance.py b/debug/bench_pairwise_distance.py new file mode 100644 index 00000000..4b1fcdbc --- /dev/null +++ b/debug/bench_pairwise_distance.py @@ -0,0 +1,69 @@ +"""Benchmark pairwise_distance: Gram matrix trick vs batched pairwise.""" + +import time +import numpy as np +import cupy as cp +from pg_gpu import HaplotypeMatrix +from pg_gpu.decomposition import pairwise_distance + + +def make_data(n_haps, n_snps, seed=42): + rng = np.random.default_rng(seed) + founders = rng.integers(0, 2, size=(5, n_snps), dtype=np.int8) + assignments = rng.integers(0, 5, size=n_haps) + haps = founders[assignments].copy() + mutations = rng.random(size=(n_haps, n_snps)) < 0.02 + haps ^= mutations.astype(np.int8) + positions = np.arange(n_snps) * 100 + hm = HaplotypeMatrix(haps, positions, positions[0], positions[-1]) + hm.transfer_to_gpu() + return hm + + +def bench(fn, n_reps=5): + fn() + cp.cuda.Device(0).synchronize() + times = [] + for _ in range(n_reps): + cp.cuda.Device(0).synchronize() + t0 = time.perf_counter() + fn() + cp.cuda.Device(0).synchronize() + times.append(time.perf_counter() - t0) + return np.median(times) * 1000 + + +def main(): + # Correctness: compare euclidean distances with scipy reference + print("Correctness check:") + from scipy.spatial.distance import pdist + for n_haps, n_snps in [(20, 200), (100, 2000)]: + hm = make_data(n_haps, n_snps) + d_gpu = pairwise_distance(hm, metric='euclidean') + hap_cpu = hm.haplotypes.get().astype(np.float64) + d_ref = pdist(hap_cpu, metric='euclidean') + max_diff = np.max(np.abs(d_gpu - d_ref)) + rel_err = max_diff / np.max(np.abs(d_ref)) + status = "PASS" if rel_err < 1e-10 else "FAIL" + print(f" {n_haps} haps x {n_snps} snps: max_rel_err={rel_err:.2e} {status}") + + # Speed + configs = [ + (100, 10000), + (100, 50000), + (200, 10000), + (200, 50000), + (200, 100000), + ] + + print(f"\n{'n_haps':>7} {'n_snps':>8} | {'euclidean (ms)':>14} {'sqeuclidean (ms)':>16}") + print("-" * 55) + for 
n_haps, n_snps in configs: + hm = make_data(n_haps, n_snps) + t_euc = bench(lambda: pairwise_distance(hm, metric='euclidean')) + t_sqe = bench(lambda: pairwise_distance(hm, metric='sqeuclidean')) + print(f"{n_haps:>7} {n_snps:>8} | {t_euc:>14.2f} {t_sqe:>16.2f}") + + +if __name__ == "__main__": + main() diff --git a/pg_gpu/decomposition.py b/pg_gpu/decomposition.py index 6d621418..a6970292 100644 --- a/pg_gpu/decomposition.py +++ b/pg_gpu/decomposition.py @@ -358,20 +358,46 @@ def pairwise_distance(haplotype_matrix: HaplotypeMatrix, complete = missing_per_var == 0 hap = hap[:, complete] - X = cp.where(hap >= 0, hap, 0).astype(cp.float64) - valid_mask = (hap >= 0).astype(cp.float64) - has_missing = cp.any(hap < 0) - n = X.shape[0] + has_missing = bool(cp.any(hap < 0).get()) + n = hap.shape[0] + m = hap.shape[1] if metric in ('euclidean', 'sqeuclidean', 'cityblock'): + # Fast path: chunked Gram trick for euclidean/sqeuclidean + # without missing data. Never materializes the full (n, m) + # float64 matrix -- only (n, chunk_size) slices at a time. + if not has_missing and metric in ('euclidean', 'sqeuclidean'): + from ._memutil import (estimate_variant_chunk_size, + free_gpu_pool) + chunk_size = estimate_variant_chunk_size( + n, bytes_per_element=8, n_intermediates=2) + G = cp.zeros((n, n), dtype=cp.float64) + for col_start in range(0, m, chunk_size): + col_end = min(col_start + chunk_size, m) + X_chunk = hap[:, col_start:col_end].astype(cp.float64) + G += X_chunk @ X_chunk.T + del X_chunk + free_gpu_pool() + # d²(i,j) = ||x_i||² + ||x_j||² - 2*x_i·x_j + norms_sq = cp.diag(G) + D2 = norms_sq[:, None] + norms_sq[None, :] - 2.0 * G + D2 = cp.maximum(D2, 0.0) + idx_i, idx_j = cp.triu_indices(n, k=1) + d = D2[idx_i, idx_j] + if metric == 'euclidean': + d = cp.sqrt(d) + return d.get() + + # General batched path (missing data or cityblock). + # Needs X and valid_mask, but these are accessed by row pairs + # so the full matrix is required. 
+ X = cp.where(hap >= 0, hap, 0).astype(cp.float64) + valid_mask = (hap >= 0).astype(cp.float64) idx_i, idx_j = cp.triu_indices(n, k=1) n_pairs = len(idx_i) - # Estimate batch size from available GPU memory - n_variants = X.shape[1] free_mem = cp.cuda.Device().mem_info[0] - # Each pair needs ~3 float64 arrays of n_variants (diff, joint, result) - bytes_per_pair = n_variants * 8 * 3 + bytes_per_pair = m * 8 * 3 batch_size = max(1, min(n_pairs, int(free_mem * 0.3 / bytes_per_pair))) dist_parts = [] @@ -381,20 +407,18 @@ def pairwise_distance(haplotype_matrix: HaplotypeMatrix, bj = idx_j[start:end] if has_missing: - # only compare at jointly-valid sites joint = valid_mask[bi] * valid_mask[bj] n_joint = cp.sum(joint, axis=1) else: - n_joint = cp.float64(X.shape[1]) + n_joint = cp.float64(m) if metric == 'cityblock': raw = cp.sum(cp.abs(X[bi] - X[bj]) * (joint if has_missing else 1.0), axis=1) else: raw = cp.sum(((X[bi] - X[bj]) ** 2) * (joint if has_missing else 1.0), axis=1) - # normalize by jointly-valid sites if has_missing: - d = cp.where(n_joint > 0, raw * X.shape[1] / n_joint, 0.0) + d = cp.where(n_joint > 0, raw * m / n_joint, 0.0) else: d = raw @@ -405,6 +429,7 @@ def pairwise_distance(haplotype_matrix: HaplotypeMatrix, return cp.concatenate(dist_parts).get() else: from scipy.spatial.distance import pdist + X = cp.where(hap >= 0, hap, 0).astype(cp.float64) X_cpu = X.get() return pdist(X_cpu, metric=metric)