# =====================================================================
# Patch context: pyrecest/distributions/__init__.py
# The accompanying hunks register the new class in the package API:
#     from .hypersphere_subset.complex_bingham_distribution import (
#         ComplexBinghamDistribution,
#     )
# and add "ComplexBinghamDistribution" to __all__.
# =====================================================================

# =====================================================================
# New file: pyrecest/distributions/hypersphere_subset/
#           complex_bingham_distribution.py
# =====================================================================
# pylint: disable=no-name-in-module,no-member,redefined-builtin
"""Complex Bingham Distribution.

Ported from the MATLAB libDirectional library:
    ComplexBinghamDistribution.m (libDirectional/lib/distributions/complexHypersphere/)

Reference:
    Kent, J. T. "The Complex Bingham Distribution and Shape Analysis."
    Journal of the Royal Statistical Society. Series B (Methodological), 1994, 285-299.
"""
from scipy.optimize import least_squares

from pyrecest.backend import (
    abs,
    all,
    allclose,
    arange,
    argsort,
    array,
    asarray,
    complex128,
    concatenate,
    conj,
    cumsum,
    diff,
    diag,
    einsum,
    empty,
    exp,
    linalg,
    linspace,
    log,
    max,
    maximum,
    minimum,
    pi,
    prod,
    random,
    real,
    sign,
    sort,
    sqrt,
    sum,
    zeros,
)


class ComplexBinghamDistribution:
    """Complex Bingham distribution on the complex unit hypersphere.

    The distribution is defined on the complex unit sphere
    S^{2d-1} = {z in C^d : ||z|| = 1} with pdf

        p(z) proportional to exp(z^H B z),

    where B is a d x d Hermitian parameter matrix.

    Attributes
    ----------
    B : array, shape (d, d), dtype complex128
        Hermitian parameter matrix.
    dim : int
        Complex dimension d.
    log_norm_const : float
        Negative log normalization constant (i.e. -log C(B) where C(B) is
        the normalising constant), stored so that
        pdf(z) = exp(log_norm_const + Re(z^H B z)).
    """

    def __init__(self, B):
        """Construct a ComplexBinghamDistribution.

        Parameters
        ----------
        B : array_like, shape (d, d)
            Hermitian parameter matrix (B == B^H must hold).
        """
        B = asarray(B, dtype=complex128)
        assert allclose(B, conj(B).T, atol=1e-10), "B must be Hermitian"
        self.B = B
        self.dim = B.shape[0]
        self.log_norm_const = ComplexBinghamDistribution.log_norm(B)

    # ------------------------------------------------------------------
    # Core methods
    # ------------------------------------------------------------------

    def pdf(self, z):
        """Evaluate the pdf at one or more points on the complex unit sphere.

        Parameters
        ----------
        z : array_like, shape (d,) or (d, n)
            Each column (or the single vector) is a point on the complex unit
            sphere.

        Returns
        -------
        array, shape (n,) or float
            Pdf value(s).
        """
        z = asarray(z, dtype=complex128)
        single = z.ndim == 1
        if single:
            z = z[:, None]
        # Re(z^H B z) for each column
        Bz = self.B @ z  # (d, n)
        vals = real(einsum("ij,ij->j", conj(z), Bz))  # shape (n,)
        p = exp(self.log_norm_const + vals)
        return float(p[0]) if single else p

    def sample(self, n):
        """Draw samples from the complex Bingham distribution.

        Uses the rejection-sampling algorithm of
        Kent, Constable & Er (2004) "Simulation for the complex Bingham
        distribution", Statistics and Computing, 14, 53-57.

        Parameters
        ----------
        n : int
            Number of samples.

        Returns
        -------
        array, shape (d, n), dtype complex128
            Sampled unit vectors (each column has unit norm).
        """
        d = self.dim
        if d < 2:
            raise ValueError("Sampling requires d >= 2.")

        # Eigendecomposition of -B (eigenvectors V, eigenvalues Lambda of -B)
        eigenvalues_neg, V = linalg.eigh(-self.B)  # sorted ascending
        # Sort descending
        idx = argsort(eigenvalues_neg)[::-1]
        eigenvalues_neg = eigenvalues_neg[idx]
        V = V[:, idx]

        # Shift so the last (smallest) eigenvalue of -B becomes 0
        Lambda = eigenvalues_neg - eigenvalues_neg[-1]  # Lambda >= 0, Lambda[-1] = 0

        # Precompute for truncated-exponential CDF inversion
        Lam = Lambda[:-1]  # shape (d-1,)

        samples = zeros((d, n), dtype=complex128)
        for k in range(n):
            # Rejection loop
            while True:
                S = zeros(d)
                U = random.uniform(size=(int(d - 1),))
                for i in range(d - 1):
                    if Lambda[i] < 0.03:
                        # Nearly-uniform truncated exponential
                        S[i] = U[i]
                    else:
                        S[i] = -(1.0 / Lam[i]) * log(
                            1.0 - U[i] * (1.0 - exp(-Lam[i]))
                        )
                if sum(S[:-1]) < 1.0:
                    break
            S[-1] = 1.0 - sum(S[:-1])

            # Random phases
            theta = 2.0 * pi * random.uniform(size=(int(d),))
            weighted_phases = sqrt(S) * exp(1j * theta)
            samples[:, k] = V @ weighted_phases

        return samples

    # ------------------------------------------------------------------
    # Static / class methods
    # ------------------------------------------------------------------

    @staticmethod
    def log_norm(B):
        """Compute the *negative* log normalization constant -log C(B).

        The formula used is (Kent 1994, eq. 2.3) after eigenvalue shift:

            C(lambda) = 2*pi^d * sum_j exp(lambda_j) / prod_{k!=j}(lambda_j - lambda_k)

        For nearly-equal eigenvalues a small perturbation is applied so that
        the Vandermonde denominators remain well-conditioned, matching the
        approach in the original MATLAB implementation. The all-eigenvalues-
        near-zero (uniform) case is detected *before* perturbing: the
        perturbation enforces pairwise gaps >= 0.01 and would otherwise make
        the exact uniform branch unreachable for d >= 2, replacing the exact
        limit C = 2*pi^d / (d-1)! with an ~1% approximation.

        Parameters
        ----------
        B : array_like, shape (d, d)
            Hermitian parameter matrix.

        Returns
        -------
        float
            Negative log normalization constant.
        """
        B = asarray(B, dtype=complex128)
        d = B.shape[0]

        # Real eigenvalues of a Hermitian matrix
        eigenvalues = linalg.eigvalsh(B)  # sorted ascending

        # Shift so the maximum eigenvalue is 0
        eigenvalue_shift = float(eigenvalues[-1])
        eigenvalues = eigenvalues - eigenvalue_shift

        if all(abs(eigenvalues) < 1e-3):
            # All eigenvalues near zero: exact limiting uniform value
            # C = 2*pi^d / (d-1)!  (checked before perturbation, see docstring)
            log_C_shifted = log(2.0) + d * log(pi) - float(
                sum(log(arange(1, d)))
            )
        else:
            # Perturb near-equal eigenvalues for numerical stability
            eigenvalues = ComplexBinghamDistribution._perturb_eigenvalues(
                eigenvalues
            )
            # Analytical formula: C = 2*pi^d * sum_j exp(lambda_j) / prod_{k!=j}(lambda_j - lambda_k)
            log_C_shifted = ComplexBinghamDistribution._log_norm_from_eigenvalues(
                eigenvalues
            )

        # Apply eigenvalue-shift correction: C(lambda) = exp(shift) * C(lambda - shift)
        log_C = log_C_shifted + eigenvalue_shift

        return -log_C  # inverted convention: log_norm_const = -log C

    @staticmethod
    def _log_norm_from_eigenvalues(eigenvalues):
        """Log normalization for shifted eigenvalues (max = 0) via partial fractions.

        Computes log(2*pi^d * sum_j exp(lambda_j) / prod_{k!=j}(lambda_j - lambda_k)).
        """
        d = len(eigenvalues)
        log_prefix = log(2.0) + d * log(pi)

        # For each j compute sign_j * exp(log_term_j) where
        #   log_term_j = lambda_j - sum_{k!=j} log|lambda_j - lambda_k|
        # diff_matrix[j, k] = eigenvalues[j] - eigenvalues[k] for j != k
        diff_matrix = eigenvalues[:, None] - eigenvalues[None, :]
        log_terms = empty(d)
        signs = empty(d)
        for j in range(d):
            mask = arange(d) != j
            diffs = diff_matrix[j, mask]
            signs[j] = prod(sign(diffs))
            log_terms[j] = eigenvalues[j] - sum(log(abs(diffs)))

        # log(|sum_j sign_j * exp(log_term_j)|) via log-sum-exp
        max_log = max(log_terms)
        scaled = sum(signs * exp(log_terms - max_log))
        log_sum = max_log + log(abs(scaled))

        return log_prefix + log_sum

    @staticmethod
    def _perturb_eigenvalues(eigenvalues):
        """Sort eigenvalues descending and enforce minimum spacing of 0.01.

        Mirrors MATLAB's makeSureEigenvaluesAreNotTooClose.
        """
        lam = sort(eigenvalues)[::-1].copy()
        diffs = diff(lam)  # non-positive for sorted-descending
        diffs = minimum(diffs, -0.01)  # enforce gap >= 0.01
        lam[1:] = lam[0] + cumsum(diffs)
        return lam

    @classmethod
    def fit(cls, Z):
        """Maximum-likelihood fit of a complex Bingham distribution to data.

        Parameters
        ----------
        Z : array_like, shape (d, n)
            Complex unit vectors (columns); it is assumed that ||z_i|| = 1.

        Returns
        -------
        ComplexBinghamDistribution
        """
        Z = asarray(Z, dtype=complex128)
        n = Z.shape[1]
        S = Z @ conj(Z).T / n  # sample scatter matrix
        B = cls._estimate_parameter_matrix(S)
        return cls(B)

    @staticmethod
    def _estimate_parameter_matrix(S):
        """Compute the ML estimate of B from the scatter matrix S.

        The eigenvectors of B equal those of S. The eigenvalues of B are
        found by solving the moment-matching equations

            d log C(lambda) / d lambda_k = s_k,   k = 1, ..., d

        where s_k are the eigenvalues of S, using a least-squares solver
        with finite-difference gradients.

        Parameters
        ----------
        S : array, shape (d, d), complex Hermitian
            Sample scatter matrix (E[z z^H] estimate).

        Returns
        -------
        array, shape (d, d), complex Hermitian
            Estimated parameter matrix B.
        """
        d = S.shape[0]
        eigenvalues_S, V = linalg.eigh(S)  # ascending

        def grad_log_c(lam):
            """Numerical gradient of log C via forward finite differences."""
            eps = 1e-6
            B_diag = diag(array(lam, dtype=complex128))
            log_c0 = ComplexBinghamDistribution.log_norm(B_diag)
            grad = empty(d)
            for i in range(d):
                lam_p = lam.copy()
                lam_p[i] += eps
                log_cp = ComplexBinghamDistribution.log_norm(
                    diag(array(lam_p, dtype=complex128))
                )
                # log_norm_const = -log C, so d(log C)/dλ_i = -d(log_norm_const)/dλ_i
                grad[i] = (-log_cp - (-log_c0)) / eps
            return grad

        # Initial guess: spread eigenvalues below zero
        initial_eigenvalues = linspace(-(d - 1) * 10, -10, int(d - 1))

        def residuals(x):
            lam = concatenate([x, array([0.0])])
            return grad_log_c(lam) - eigenvalues_S

        result = least_squares(
            residuals,
            initial_eigenvalues,
            method="lm",
            ftol=1e-15,
            xtol=1e-10,
            max_nfev=int(1e4),
        )
        lam_B = concatenate([result.x, array([0.0])])
        B = V @ diag(array(lam_B, dtype=complex128)) @ conj(V).T
        B = 0.5 * (B + conj(B).T)  # enforce exact Hermitian symmetry
        return B

    @staticmethod
    def cauchy_schwarz_divergence(cB1, cB2):
        """Cauchy-Schwarz divergence between two complex Bingham distributions.

            D_CS(p, q) = half * [log C(2*B1) + log C(2*B2)] - log C(B1+B2) >= 0

        Using the stored negative log normalization (log_norm = -log C):

            D_CS = log_norm(B1+B2) - half * [log_norm(2*B1) + log_norm(2*B2)]

        This matches the MATLAB libDirectional CauchySchwarzDivergence convention.

        Parameters
        ----------
        cB1, cB2 : ComplexBinghamDistribution or array_like
            Distributions or Hermitian parameter matrices.

        Returns
        -------
        float
            Non-negative divergence value.
        """
        if isinstance(cB1, ComplexBinghamDistribution):
            B1 = cB1.B
        else:
            B1 = asarray(cB1, dtype=complex128)
        if isinstance(cB2, ComplexBinghamDistribution):
            B2 = cB2.B
        else:
            B2 = asarray(cB2, dtype=complex128)

        assert allclose(B1, conj(B1).T, atol=1e-10), "B1 must be Hermitian"
        assert allclose(B2, conj(B2).T, atol=1e-10), "B2 must be Hermitian"

        log_c1 = ComplexBinghamDistribution.log_norm(2.0 * B1)
        log_c2 = ComplexBinghamDistribution.log_norm(2.0 * B2)
        log_c3 = ComplexBinghamDistribution.log_norm(B1 + B2)

        return log_c3 - 0.5 * (log_c1 + log_c2)


# =====================================================================
# New file: pyrecest/tests/distributions/
#           test_complex_bingham_distribution.py
# =====================================================================
import math
import unittest

import numpy.testing as npt
import pyrecest.backend

# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import (
    array,
    diag,
    exp,
    linalg,
    log,
    mean,
    ones,
    pi,
    random,
    real,
    sort,
    sqrt,
)
from pyrecest.distributions import ComplexBinghamDistribution


class TestComplexBinghamDistribution(unittest.TestCase):
    """Tests for ComplexBinghamDistribution."""

    def setUp(self):
        # Simple 2x2 diagonal Hermitian B
        self.B2 = diag(array([-3.0, 0.0], dtype=complex))
        self.cB2 = ComplexBinghamDistribution(self.B2)

        # 3x3 diagonal Hermitian B
        self.B3 = diag(array([-5.0, -2.0, 0.0], dtype=complex))
        self.cB3 = ComplexBinghamDistribution(self.B3)

    def test_constructor_hermitian_check(self):
        """Non-Hermitian matrix should raise AssertionError."""
        with self.assertRaises(AssertionError):
            ComplexBinghamDistribution(
                array([[1.0, 1j], [0.0, 1.0]])
            )

    def test_log_norm_const_finite(self):
        """log_norm_const must be finite."""
        self.assertTrue(math.isfinite(self.cB2.log_norm_const))
        self.assertTrue(math.isfinite(self.cB3.log_norm_const))
self.assertTrue(math.isfinite(self.cB3.log_norm_const)) + + def test_dim(self): + self.assertEqual(self.cB2.dim, 2) + self.assertEqual(self.cB3.dim, 3) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) + def test_pdf_normalises_to_one_2d(self): + """MC check: 2-D pdf integrates to 1 over S^3.""" + random.seed(12345) + # Sample uniformly from S^3 using complex Gaussian projection + real_part = random.normal(size=(2, 200_000)) + imag_part = random.normal(size=(2, 200_000)) + raw = real_part + 1j * imag_part + norms = linalg.norm(raw, axis=0, keepdims=True) + Z = raw / norms + area = 2.0 * float(pi) ** 2 # surface area of S^3 + mc_integral = float(mean(self.cB2.pdf(Z))) * area + npt.assert_almost_equal(mc_integral, 1.0, decimal=2) + + def test_pdf_positive(self): + """pdf must return positive values.""" + z = array([1.0, 0.0], dtype=complex) + self.assertGreater(self.cB2.pdf(z), 0.0) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) + def test_pdf_batch_vs_single(self): + """Vectorised pdf matches point-by-point evaluation.""" + random.seed(0) + real_part = random.normal(size=(2, 10)) + imag_part = random.normal(size=(2, 10)) + raw = real_part + 1j * imag_part + Z = raw / linalg.norm(raw, axis=0, keepdims=True) + batch = self.cB2.pdf(Z) + single = array([self.cB2.pdf(Z[:, k]) for k in range(10)]) + npt.assert_allclose(batch, single, rtol=1e-10) + + def test_pdf_invariant_to_phase(self): + """pdf(exp(i*alpha)*z) == pdf(z) for any global phase alpha.""" + z = array([0.6 + 0.3j, 0.0], dtype=complex) + z[1] = sqrt(1 - abs(z[0]) ** 2) + p0 = self.cB2.pdf(z) + for alpha in [0.3, 1.0, float(pi)]: + p1 = self.cB2.pdf(exp(1j * alpha) * z) + npt.assert_almost_equal(p1, p0, decimal=10) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not 
supported on JAX backend", + ) + def test_sample_shape(self): + """sample returns the right shape.""" + random.seed(42) + S = self.cB2.sample(50) + self.assertEqual(S.shape, (2, 50)) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) + def test_sample_unit_norm(self): + """All samples must lie on the unit sphere.""" + random.seed(42) + S = self.cB2.sample(100) + norms = linalg.norm(S, axis=0) + npt.assert_allclose(norms, ones(100), atol=1e-12) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) + def test_sample_3d_unit_norm(self): + """3-D samples also lie on the unit sphere.""" + random.seed(7) + S = self.cB3.sample(50) + norms = linalg.norm(S, axis=0) + npt.assert_allclose(norms, ones(50), atol=1e-12) + + def test_log_norm_2d_analytic(self): + a = 3.0 + B = diag(array([-a, 0.0], dtype=complex)) + log_C_expected = float(log(2 * pi**2 / a * (1 - exp(-a)))) + log_norm_got = ComplexBinghamDistribution.log_norm(B) + npt.assert_almost_equal(-log_norm_got, log_C_expected, decimal=6) + + def test_log_norm_equal_eigenvalues(self): + """Equal eigenvalues (uniform) should not raise.""" + from pyrecest.backend import zeros + B = zeros((3, 3), dtype=complex) + log_norm = ComplexBinghamDistribution.log_norm(B) + import math + self.assertTrue(math.isfinite(log_norm)) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) + def test_fit_returns_instance(self): + """fit() returns a ComplexBinghamDistribution instance.""" + random.seed(0) + Z = self.cB2.sample(200) + cB_fit = ComplexBinghamDistribution.fit(Z) + self.assertIsInstance(cB_fit, ComplexBinghamDistribution) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) + def 
test_fit_recovers_eigenvalue_gap(self): + """Fit recovers the eigenvalue gap (up to additive const).""" + random.seed(0) + B = diag(array([-10.0, 0.0], dtype=complex)) + cB = ComplexBinghamDistribution(B) + Z = cB.sample(2000) + cB_fit = ComplexBinghamDistribution.fit(Z) + evals_fit = sort(real(linalg.eigvalsh(cB_fit.B))) + evals_true = sort(real(linalg.eigvalsh(B))) + gap_true = float(evals_true[-1] - evals_true[0]) + gap_fit = float(evals_fit[-1] - evals_fit[0]) + npt.assert_almost_equal(gap_fit, gap_true, decimal=0) + + def test_cauchy_schwarz_zero_for_identical(self): + """D_CS(p, p) should be 0.""" + d = ComplexBinghamDistribution.cauchy_schwarz_divergence(self.cB2, self.cB2) + npt.assert_almost_equal(d, 0.0, decimal=6) + + def test_cauchy_schwarz_symmetric(self): + """D_CS(p, q) == D_CS(q, p).""" + B3a = diag(array([-5.0, -1.0, 0.0], dtype=complex)) + B3b = diag(array([-3.0, -2.0, 0.0], dtype=complex)) + d_ab = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3a, B3b) + d_ba = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3b, B3a) + npt.assert_almost_equal(d_ab, d_ba, decimal=6) + + def test_cauchy_schwarz_nonneg(self): + """Cauchy-Schwarz divergence must be >= 0.""" + B3a = diag(array([-5.0, -1.0, 0.0], dtype=complex)) + B3b = diag(array([-3.0, -2.0, 0.0], dtype=complex)) + d = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3a, B3b) + self.assertGreaterEqual(d, -1e-10) + + +if __name__ == "__main__": + unittest.main()