From bb81e9f96f9f7fca4fbdeb2c39b5f95338ccf83f Mon Sep 17 00:00:00 2001 From: Andrew Kern Date: Fri, 12 Jun 2026 17:22:43 -0700 Subject: [PATCH] fix NaN in genotype LD stats when missing data shrinks per-pair sample counts Missing data can drop a population's per-pair valid sample count below the order of a statistic's denominator (n*(n-1)*(n-2) for between-pop Dz, 1/n for the normalized pi2 frequencies, n*(n-1) for heterozygosity). The denominator is then zero and the unbiased estimator is 0/0, which is undefined in moments too. The between-pop Dz kernels, the pi2 normalized-frequency path, and the heterozygosity sum divided without guarding, so one degenerate pair or site emitted a NaN that poisoned the whole r-bin sum. Guard each denominator so a degenerate pair/site contributes 0 -- excluding the undefined term -- matching the single-pop Dz and pi2 kernels that already guard the same way. On well-sampled data the guards are exact no-ops; pg_gpu still matches moments' Python reference (the PR #258 -20 pi2 coefficient) to machine precision across every DD/Dz/pi2 case. --- pg_gpu/genotype_kernels.py | 29 ++++- pg_gpu/moments_ld.py | 12 +- tests/test_ld_missing_data_degenerate.py | 159 +++++++++++++++++++++++ 3 files changed, 192 insertions(+), 8 deletions(-) create mode 100644 tests/test_ld_missing_data_degenerate.py diff --git a/pg_gpu/genotype_kernels.py b/pg_gpu/genotype_kernels.py index 226d7596..955da129 100644 --- a/pg_gpu/genotype_kernels.py +++ b/pg_gpu/genotype_kernels.py @@ -196,13 +196,22 @@ def _launch(kernel, args, M): Vcoef[idx]=d*(1.+a)+.25*MB; }''', "k", options=("-std=c++11",)) +# The between-pop Dz denominators are products of falling factorials of the +# per-pair valid sample counts (n, n*(n-1), n*(n-1)*(n-2)). With missing data a +# pop's valid count at a pair can fall below the order of its factorial -- e.g. +# a 14-sample pop with only two individuals typed at both loci gives +# n*(n-1)*(n-2)=0 -- so the unguarded ratio was 0/0 = NaN, and one such pair +# poisoned the whole bin sum. A degenerate pair contributes nothing, so emit 0 +# when the denominator is non-positive (matching _DZ_III_KERN and the pi2 +# kernels, which already guard the same way). _DZ_DISTINCT_KERN = cp.RawKernel(r''' extern "C" __global__ void k(const double*ns,const double*D,const double*A,const double*Bv, const int*I,const int*J,const int*K,double*out,const int N){ int t=blockDim.x*blockIdx.x+threadIdx.x; if(t>=N)return; int i=I[t],j=J[t],k=K[t]; - out[t]=2.*D[i]*Bv[j]*A[k]/(ns[i]*(ns[i]-1.)*ns[j]*ns[k]); + double d=ns[i]*(ns[i]-1.)*ns[j]*ns[k]; + out[t]=(d>0.0)?2.*D[i]*Bv[j]*A[k]/d:0.0; }''', "k", options=("-std=c++11",)) _DZ_IIJ_KERN = cp.RawKernel(r''' @@ -211,7 +220,8 @@ def _launch(kernel, args, M): const int*I,const int*J,double*out,const int N){ int t=blockDim.x*blockIdx.x+threadIdx.x; if(t>=N)return; int i=I[t],j=J[t]; - out[t]=2.*Ucoef[i]*A[j]/(ns[j]*ns[i]*(ns[i]-1.)*(ns[i]-2.)); + double d=ns[j]*ns[i]*(ns[i]-1.)*(ns[i]-2.); + out[t]=(d>0.0)?2.*Ucoef[i]*A[j]/d:0.0; }''', "k", options=("-std=c++11",)) _DZ_IJI_KERN = cp.RawKernel(r''' @@ -220,7 +230,8 @@ def _launch(kernel, args, M): const int*I,const int*J,double*out,const int N){ int t=blockDim.x*blockIdx.x+threadIdx.x; if(t>=N)return; int i=I[t],j=J[t]; - out[t]=2.*Vcoef[i]*Bv[j]/(ns[j]*ns[i]*(ns[i]-1.)*(ns[i]-2.)); + double d=ns[j]*ns[i]*(ns[i]-1.)*(ns[i]-2.); + out[t]=(d>0.0)?2.*Vcoef[i]*Bv[j]/d:0.0; }''', "k", options=("-std=c++11",)) _DZ_IJJ_KERN = cp.RawKernel(r''' @@ -229,7 +240,8 @@ def _launch(kernel, args, M): const int*I,const int*J,double*out,const int N){ int t=blockDim.x*blockIdx.x+threadIdx.x; if(t>=N)return; int i=I[t],j=J[t]; - out[t]=2.*D[i]*Wjj[j]/(ns[j]*(ns[j]-1.)*ns[i]*(ns[i]-1.)); + double d=ns[j]*(ns[j]-1.)*ns[i]*(ns[i]-1.); + out[t]=(d>0.0)?2.*D[i]*Wjj[j]/d:0.0; }''', "k", options=("-std=c++11",)) _DZ_III_KERN = cp.RawKernel(r''' @@ -516,7 +528,14 @@ def __init__(self, pops): def flat(arrs): return cp.ascontiguousarray(cp.concatenate(arrs)) - inv_n = [1.0 / p.n for p in pops] + # Guard n==0 (a population with no individual genotyped at both loci of + # a pair, which missing data allows): 1/n would be inf and feed NaN into + # the normalized-frequency pi2 kernels (alldiff, iikl, ijkk, shared, + # none of which divide by a per-pair denominator that could catch it). + # The allele counts pA..qB are 0 whenever n is 0, so a maximum(n,1) + # divisor yields frequency 0 and the degenerate pair contributes + # nothing -- matching the falling-factorial guards used elsewhere. + inv_n = [1.0 / cp.maximum(p.n, 1.0) for p in pops] # Normalized frequencies for between-pop factored kernels self.p = flat([pops[i].pA * inv_n[i] for i in range(P)]) self.r = flat([pops[i].qA * inv_n[i] for i in range(P)]) diff --git a/pg_gpu/moments_ld.py b/pg_gpu/moments_ld.py index b08bc832..62c853a6 100644 --- a/pg_gpu/moments_ld.py +++ b/pg_gpu/moments_ld.py @@ -385,18 +385,24 @@ def _compute_heterozygosity(mat, pops, use_genotypes=False): ref_counts.append(n_hap - alt) hap_sizes.append(n_hap) + # A site where a population has no genotyped individual (n_hap == 0, which + # missing data allows even after the union-biallelic filter) makes the + # haploid-size denominator zero. The alt/ref counts are also zero there, so + # the site's heterozygosity contribution is zero; clamp the denominator to + # avoid 0/0 = NaN poisoning the per-site sum. result = {} for ii in range(num_pops): for jj in range(ii, num_pops): if ii == jj: + denom = cp.maximum(hap_sizes[ii] * (hap_sizes[ii] - 1), 1.0) val = float(cp.sum( - 2.0 * ref_counts[ii] * alt_counts[ii] - / (hap_sizes[ii] * (hap_sizes[ii] - 1)) + 2.0 * ref_counts[ii] * alt_counts[ii] / denom ).get()) else: + denom = cp.maximum(hap_sizes[ii] * hap_sizes[jj], 1.0) val = float(cp.sum( (ref_counts[ii] * alt_counts[jj] + alt_counts[ii] * ref_counts[jj]) - / (hap_sizes[ii] * hap_sizes[jj]) + / denom ).get()) result[f"H_{ii}_{jj}"] = val diff --git a/tests/test_ld_missing_data_degenerate.py b/tests/test_ld_missing_data_degenerate.py new file mode 100644 index 00000000..a8f0fd20 --- /dev/null +++ b/tests/test_ld_missing_data_degenerate.py @@ -0,0 +1,159 @@ +"""Genotype LD with degenerate per-pair / per-site sample counts (issue #123). + +Two layers: + +* Correctness -- pg_gpu's genotype kernels match moments' Python reference + (``stats_from_genotype_counts``, with the PR #258 ``-20`` pi2(i,i;i,i) + coefficient) to machine precision on well-sampled count vectors, applying the + same index symmetrization moments' pipeline uses. This needs the moments env. + +* Robustness -- when missing data drops a population's per-pair sample count + below the order of a statistic's denominator (``n*(n-1)*(n-2)`` etc.), the + per-pair statistic is 0/0, which is *undefined* in moments too (its formula + raises / yields nan). pg_gpu contributes 0 there -- excluding the undefined + pair -- instead of poisoning the whole r-bin sum with a NaN. These checks + need only a GPU. +""" +import numpy as np +import cupy as cp +import pytest + +from pg_gpu import GenotypeMatrix +from pg_gpu.genotype_kernels import compute_multi_pop_statistics_batch_geno +from pg_gpu.ld_pipeline import compute_genotype_counts_for_pairs +from pg_gpu.moments_ld import _compute_heterozygosity, _generate_stat_specs, _ld_names + +try: + from moments.LD import stats_from_genotype_counts as sgc + HAVE_MOMENTS = True +except Exception: + HAVE_MOMENTS = False + + +# pg_gpu's count columns are (n00,n01,n02,n10,n11,n12,n20,n21,n22); moments' +# n1..n9 are (n22,n21,...,n00) -- the reverse. +def _to_pg(mom9): + return np.asarray(mom9, dtype=np.int32)[::-1] + + +def _moments_ref(name, counts): + """Reference value for ``name`` from moments' Python per-pair functions, + with the same symmetrization ``moments.LD.Parsing._call_sgc`` applies. + ``counts`` is a list of per-population moments-order 9-count vectors.""" + parts = name.split("_") + kind, idx = parts[0], [int(x) for x in parts[1:]] + if kind == "DD": + return sgc.DD(counts, idx) + if kind == "Dz": + i, j, k = idx + if j == k: + return sgc.Dz(counts, [i, j, k]) + return 0.5 * (sgc.Dz(counts, [i, j, k]) + sgc.Dz(counts, [i, k, j])) + i, j, k, l = idx + if i == j and k == l and i == k: + return sgc.pi2(counts, [i, j, k, l]) + if i == j and k == l: + return 0.5 * (sgc.pi2(counts, [i, j, k, l]) + sgc.pi2(counts, [k, l, i, j])) + if i == j: + return 0.25 * sum(sgc.pi2(counts, o) for o in + ([i, j, k, l], [i, j, l, k], [k, l, i, j], [l, k, i, j])) + if k == l: + return 0.25 * sum(sgc.pi2(counts, o) for o in + ([i, j, k, l], [j, i, k, l], [k, l, i, j], [k, l, j, i])) + return (1.0 / 8) * sum(sgc.pi2(counts, o) for o in ( + [i, j, k, l], [i, j, l, k], [j, i, k, l], [j, i, l, k], + [k, l, i, j], [l, k, i, j], [k, l, j, i], [l, k, j, i])) + + +@pytest.mark.skipif(not HAVE_MOMENTS, reason="needs the moments pixi env") +def test_genotype_stats_match_moments_python(): + # Random, well-sampled 3-population genotype counts (n>=8 per pair) exercise + # every DD/Dz/pi2 case, including the between-population kernels. pg_gpu must + # equal moments' Python reference to floating-point precision. + rng = np.random.default_rng(0) + n_pops, n_pairs = 3, 300 + counts_mom, counts_pg = [], [] + for _ in range(n_pops): + mom = [rng.multinomial(int(rng.integers(8, 40)), np.ones(9) / 9) + for _ in range(n_pairs)] + counts_mom.append(mom) + counts_pg.append(cp.asarray(np.array([_to_pg(c) for c in mom]), + dtype=cp.int32)) + + names = _ld_names(n_pops) + gpu = cp.asnumpy(compute_multi_pop_statistics_batch_geno( + counts_pg, [None] * n_pops, None, _generate_stat_specs(n_pops))) + + for k, nm in enumerate(names): + ref = np.array([_moments_ref(nm, [counts_mom[p][pp] for p in range(n_pops)]) + for pp in range(n_pairs)]) + np.testing.assert_allclose(gpu[:, k], ref, rtol=1e-9, atol=1e-12, + err_msg=f"{nm} disagrees with moments") + + +def _pop_counts(g, pop_idx, n_var): + i, j = cp.triu_indices(n_var, k=1) + return compute_genotype_counts_for_pairs( + cp.asarray(g, dtype=cp.int8), + i.astype(cp.int32), j.astype(cp.int32), + cp.asarray(pop_idx, dtype=cp.int32)) + + +def test_between_pop_dz_degenerate_pair_is_finite(): + # pop0 small, with missing data leaving one individual typed at both loci of + # the pair -- the n*(n-1)*(n-2)=0 case. pops 1, 2 fully typed. + M = -1 + g = np.array([ + [1, 1], [0, M], [M, 2], [M, M], + [0, 1], [1, 1], [2, 0], [1, 2], [0, 0], [1, 1], [2, 1], [0, 2], + [1, 0], [0, 1], [1, 1], [2, 2], [0, 1], [1, 0], [2, 1], [1, 1], + ], dtype=np.int8) + pops = [list(range(0, 4)), list(range(4, 12)), list(range(12, 20))] + counts, n_valid = zip(*(_pop_counts(g, p, g.shape[1]) for p in pops)) + assert int(counts[0].sum()) == 1 + + stats = cp.asnumpy(compute_multi_pop_statistics_batch_geno( + list(counts), list(n_valid), None, _generate_stat_specs(3))) + assert np.all(np.isfinite(stats)) + + names = _ld_names(3) + for nm in ("Dz_0_0_1", "Dz_0_1_1", "Dz_0_1_2"): + assert stats[0, names.index(nm)] == 0.0 + + +def test_between_pop_pi2_zero_sample_pop_is_finite(): + # pop0 entirely missing at the second locus -> zero jointly-typed -> the + # normalized-frequency pi2 kernels (alldiff, iikl, ijkk, shared) would emit + # NaN from 1/n=inf without the guard. + M = -1 + g = np.array([ + [1, M], [0, M], [2, M], + [0, 1], [1, 1], [2, 0], [1, 2], + [1, 0], [0, 1], [1, 1], [2, 2], + [0, 2], [1, 0], [2, 1], [1, 1], + ], dtype=np.int8) + pops = [list(range(0, 3)), list(range(3, 7)), + list(range(7, 11)), list(range(11, 15))] + counts, n_valid = zip(*(_pop_counts(g, p, g.shape[1]) for p in pops)) + assert int(counts[0].sum()) == 0 + + stats = cp.asnumpy(compute_multi_pop_statistics_batch_geno( + list(counts), list(n_valid), None, _generate_stat_specs(4))) + assert np.all(np.isfinite(stats)) + + names = _ld_names(4) + for nm in ("pi2_0_1_2_3", "pi2_0_2_1_3"): + assert stats[0, names.index(nm)] == 0.0 + + +def test_heterozygosity_zero_sample_pop_site_is_finite(): + # A site where a whole population is missing makes the haploid-size + # denominator zero; the per-site contribution is 0, not NaN. + M = -1 + geno = np.array([[1, M], [0, M], [2, M], + [0, 1], [1, 0], [2, 1], [1, 2]], dtype=np.int8) + gm = GenotypeMatrix(geno, np.array([100, 200]), 100, 200) + gm.transfer_to_gpu() + gm.sample_sets = {"A": [0, 1, 2], "B": [3, 4, 5, 6]} + het = _compute_heterozygosity(gm, ["A", "B"], use_genotypes=True) + assert all(np.isfinite(v) for v in het.values())