From 081cb59a10a1d232418e3081a05a5e8c7be891c6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Apr 2026 05:48:06 +0000 Subject: [PATCH 1/4] Port ComplexBinghamDistribution from libDirectional MATLAB to Python Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/f1b48e18-0fb4-4c4f-a65b-0fa91ff8cb1c Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com> --- pyrecest/distributions/__init__.py | 2 + .../complex_bingham_distribution.py | 342 ++++++++++++++++++ .../test_complex_bingham_distribution.py | 142 ++++++++ 3 files changed, 486 insertions(+) create mode 100644 pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py create mode 100644 pyrecest/tests/distributions/test_complex_bingham_distribution.py diff --git a/pyrecest/distributions/__init__.py b/pyrecest/distributions/__init__.py index 818450ca5..b06d4526e 100644 --- a/pyrecest/distributions/__init__.py +++ b/pyrecest/distributions/__init__.py @@ -154,6 +154,7 @@ AbstractSphericalHarmonicsDistribution, ) from .hypersphere_subset.bingham_distribution import BinghamDistribution +from .hypersphere_subset.complex_bingham_distribution import ComplexBinghamDistribution from .hypersphere_subset.custom_hemispherical_distribution import ( CustomHemisphericalDistribution, ) @@ -341,6 +342,7 @@ "AbstractSphericalDistribution", "AbstractSphericalHarmonicsDistribution", "BinghamDistribution", + "ComplexBinghamDistribution", "CustomHemisphericalDistribution", "CustomHyperhemisphericalDistribution", "CustomHypersphericalDistribution", diff --git a/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py new file mode 100644 index 000000000..edd2851a1 --- /dev/null +++ b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py @@ -0,0 +1,342 @@ +# pylint: disable=no-name-in-module +"""Complex Bingham Distribution. + +Ported from the MATLAB libDirectional library: + ComplexBinghamDistribution.m (libDirectional/lib/distributions/complexHypersphere/) + +Reference: + Kent, J. T. "The Complex Bingham Distribution and Shape Analysis." + Journal of the Royal Statistical Society. Series B (Methodological), 1994, 285-299. +""" +import numpy as np +from scipy.linalg import eigh +from scipy.optimize import least_squares + + +class ComplexBinghamDistribution: + """Complex Bingham distribution on the complex unit hypersphere. + + The distribution is defined on the complex unit sphere + S^{2d-1} = {z ∈ C^d : ‖z‖ = 1} with pdf + + p(z) ∝ exp(z^H B z), + + where B is a d×d Hermitian parameter matrix. + + Attributes + ---------- + B : numpy.ndarray, shape (d, d), dtype complex128 + Hermitian parameter matrix. + dim : int + Complex dimension d. + log_norm_const : float + Negative log normalization constant (i.e. −log C(B) where C(B) is + the normalising constant), stored so that + pdf(z) = exp(log_norm_const + Re(z^H B z)). + """ + + def __init__(self, B): + """Construct a ComplexBinghamDistribution. + + Parameters + ---------- + B : array_like, shape (d, d) + Hermitian parameter matrix (B == B^H must hold). + """ + B = np.asarray(B, dtype=complex) + assert np.allclose(B, B.conj().T, atol=1e-10), "B must be Hermitian" + self.B = B + self.dim = B.shape[0] + self.log_norm_const = ComplexBinghamDistribution.log_norm(B) + + # ------------------------------------------------------------------ + # Core methods + # ------------------------------------------------------------------ + + def pdf(self, z): + """Evaluate the pdf at one or more points on the complex unit sphere. + + Parameters + ---------- + z : array_like, shape (d,) or (d, n) + Each column (or the single vector) is a point on the complex unit + sphere. + + Returns + ------- + numpy.ndarray, shape (n,) or float + Pdf value(s). + """ + z = np.asarray(z, dtype=complex) + single = z.ndim == 1 + if single: + z = z[:, np.newaxis] + # Re(z^H B z) for each column + Bz = self.B @ z # (d, n) + vals = np.real(np.einsum("ij,ij->j", z.conj(), Bz)) # shape (n,) + p = np.exp(self.log_norm_const + vals) + return float(p[0]) if single else p + + def sample(self, n): + """Draw samples from the complex Bingham distribution. + + Uses the rejection-sampling algorithm of + Kent, Constable & Er (2004) "Simulation for the complex Bingham + distribution", *Statistics and Computing*, 14, 53-57. + + Parameters + ---------- + n : int + Number of samples. + + Returns + ------- + numpy.ndarray, shape (d, n), dtype complex128 + Sampled unit vectors (each column has unit norm). + """ + d = self.dim + if d < 2: + raise ValueError("Sampling requires d >= 2.") + + # Eigendecomposition of -B (eigenvectors V, eigenvalues Λ of -B) + eigenvalues_neg, V = eigh(-self.B) # sorted ascending by scipy + # Sort descending + idx = np.argsort(eigenvalues_neg)[::-1] + eigenvalues_neg = eigenvalues_neg[idx] + V = V[:, idx] + + # Shift so the last (smallest) eigenvalue of -B becomes 0 + Lambda = eigenvalues_neg - eigenvalues_neg[-1] # Λ ≥ 0, Λ[-1] = 0 + + # Precompute for truncated-exponential CDF inversion + Lam = Lambda[:-1] # shape (d-1,) + + samples = np.zeros((d, n), dtype=complex) + for k in range(n): + # Rejection loop + while True: + S = np.zeros(d) + U = np.random.uniform(size=d - 1) + for i in range(d - 1): + if Lambda[i] < 0.03: + # Nearly-uniform truncated exponential + S[i] = U[i] + else: + S[i] = -(1.0 / Lam[i]) * np.log( + 1.0 - U[i] * (1.0 - np.exp(-Lam[i])) + ) + if S[:-1].sum() < 1.0: + break + S[-1] = 1.0 - S[:-1].sum() + + # Random phases + theta = 2.0 * np.pi * np.random.uniform(size=d) + W = np.sqrt(S) * np.exp(1j * theta) + samples[:, k] = V @ W + + return samples + + # ------------------------------------------------------------------ + # Static / class methods + # ------------------------------------------------------------------ + + @staticmethod + def log_norm(B): + """Compute the *negative* log normalization constant −log C(B). + + The formula used is (Kent 1994, eq. 2.3) after eigenvalue shift: + + C(λ) = 2π^d · Σ_j exp(λ_j) / Π_{k≠j}(λ_j − λ_k) + + For nearly-equal eigenvalues a small perturbation is applied so that + the Vandermonde denominators remain well-conditioned, matching the + approach in the original MATLAB implementation. + + Parameters + ---------- + B : array_like, shape (d, d) + Hermitian parameter matrix. + + Returns + ------- + float + Negative log normalization constant. + """ + B = np.asarray(B, dtype=complex) + d = B.shape[0] + + # Real eigenvalues of a Hermitian matrix + eigenvalues = np.linalg.eigvalsh(B) # sorted ascending + + # Shift so the maximum eigenvalue is 0 + eigenvalue_shift = float(eigenvalues[-1]) + eigenvalues = eigenvalues - eigenvalue_shift + + # Perturb near-equal eigenvalues for numerical stability + eigenvalues = ComplexBinghamDistribution._perturb_eigenvalues(eigenvalues) + + if np.all(np.abs(eigenvalues) < 1e-3): + # All eigenvalues near zero: use limiting uniform value C = 2π^d / (d-1)! + log_C_shifted = np.log(2.0) + d * np.log(np.pi) - float( + np.sum(np.log(np.arange(1, d))) + ) + else: + # Analytical formula: C = 2π^d · Σ_j exp(λ_j) / Π_{k≠j}(λ_j - λ_k) + log_C_shifted = ComplexBinghamDistribution._log_norm_from_eigenvalues( + eigenvalues + ) + + # Apply eigenvalue-shift correction: C(λ) = exp(shift) · C(λ - shift) + log_C = log_C_shifted + eigenvalue_shift + + return -log_C # inverted convention: log_norm_const = -log C + + @staticmethod + def _log_norm_from_eigenvalues(eigenvalues): + """Log normalization for shifted eigenvalues (max = 0) via partial fractions. + + Computes log(2π^d · Σ_j exp(λ_j) / Π_{k≠j}(λ_j - λ_k)). + """ + d = len(eigenvalues) + log_prefix = np.log(2.0) + d * np.log(np.pi) + + # For each j compute sign_j * exp(log_term_j) where + # log_term_j = λ_j - Σ_{k≠j} log|λ_j - λ_k| + log_terms = np.empty(d) + signs = np.empty(d) + for j in range(d): + diffs = eigenvalues[j] - np.delete(eigenvalues, j) + signs[j] = np.prod(np.sign(diffs)) + log_terms[j] = eigenvalues[j] - np.sum(np.log(np.abs(diffs))) + + # log(|Σ_j sign_j · exp(log_term_j)|) via log-sum-exp + max_log = np.max(log_terms) + scaled = np.sum(signs * np.exp(log_terms - max_log)) + log_sum = max_log + np.log(np.abs(scaled)) + + return log_prefix + log_sum + + @staticmethod + def _perturb_eigenvalues(eigenvalues): + """Sort eigenvalues descending and enforce minimum spacing of 0.01. + + Mirrors MATLAB's ``makeSureEigenvaluesAreNotTooClose``. + """ + lam = np.sort(eigenvalues)[::-1].copy() + diffs = np.diff(lam) # non-positive for sorted-descending + diffs = np.minimum(diffs, -0.01) # enforce gap ≥ 0.01 + lam[1:] = lam[0] + np.cumsum(diffs) + return lam + + @classmethod + def fit(cls, Z): + """Maximum-likelihood fit of a complex Bingham distribution to data. + + Parameters + ---------- + Z : array_like, shape (d, n) + Complex unit vectors (columns); it is assumed that ‖z_i‖ = 1. + + Returns + ------- + ComplexBinghamDistribution + """ + Z = np.asarray(Z, dtype=complex) + n = Z.shape[1] + S = Z @ Z.conj().T / n # sample scatter matrix + B = cls._estimate_parameter_matrix(S) + return cls(B) + + @staticmethod + def _estimate_parameter_matrix(S): + """Compute the ML estimate of B from the scatter matrix S. + + The eigenvectors of B equal those of S. The eigenvalues of B are + found by solving the moment-matching equations + + ∂ log C(λ) / ∂λ_k = s_k, k = 1, …, d + + where s_k are the eigenvalues of S, using a least-squares solver + with finite-difference gradients. + + Parameters + ---------- + S : numpy.ndarray, shape (d, d), complex Hermitian + Sample scatter matrix (E[z z^H] estimate). + + Returns + ------- + numpy.ndarray, shape (d, d), complex Hermitian + Estimated parameter matrix B. + """ + d = S.shape[0] + eigenvalues_S, V = eigh(S) # ascending + + def grad_log_C(lam): + """Numerical gradient of log C via forward finite differences.""" + eps = 1e-6 + B_diag = np.diag(lam.astype(complex)) + log_c0 = ComplexBinghamDistribution.log_norm(B_diag) + grad = np.empty(d) + for i in range(d): + lam_p = lam.copy() + lam_p[i] += eps + log_cp = ComplexBinghamDistribution.log_norm( + np.diag(lam_p.astype(complex)) + ) + # log_norm_const = -log C, so d(log C)/dλ_i = -d(log_norm_const)/dλ_i + grad[i] = (-log_cp - (-log_c0)) / eps + return grad + + # Initial guess: spread eigenvalues below zero + x0 = np.linspace(-(d - 1) * 10, -10, d - 1) + + def residuals(x): + lam = np.append(x, 0.0) + return grad_log_C(lam) - eigenvalues_S + + result = least_squares( + residuals, + x0, + method="lm", + ftol=1e-15, + xtol=1e-10, + max_nfev=int(1e4), + ) + lam_B = np.append(result.x, 0.0) + B = V @ np.diag(lam_B.astype(complex)) @ V.conj().T + B = 0.5 * (B + B.conj().T) # enforce exact Hermitian symmetry + return B + + @staticmethod + def cauchy_schwarz_divergence(cB1, cB2): + """Cauchy-Schwarz divergence between two complex Bingham distributions. + + D_CS(p, q) = log C(B₁+B₂) − ½[log C(2B₁) + log C(2B₂)] + + Parameters + ---------- + cB1, cB2 : ComplexBinghamDistribution or numpy.ndarray + Distributions or parameter matrices. + + Returns + ------- + float + """ + if isinstance(cB1, ComplexBinghamDistribution): + B1 = cB1.B + else: + B1 = np.asarray(cB1, dtype=complex) + if isinstance(cB2, ComplexBinghamDistribution): + B2 = cB2.B + else: + B2 = np.asarray(cB2, dtype=complex) + + assert np.allclose(B1, B1.conj().T, atol=1e-10), "B1 must be Hermitian" + assert np.allclose(B2, B2.conj().T, atol=1e-10), "B2 must be Hermitian" + + log_c1 = ComplexBinghamDistribution.log_norm(2.0 * B1) + log_c2 = ComplexBinghamDistribution.log_norm(2.0 * B2) + log_c3 = ComplexBinghamDistribution.log_norm(B1 + B2) + + return log_c3 - 0.5 * (log_c1 + log_c2) diff --git a/pyrecest/tests/distributions/test_complex_bingham_distribution.py b/pyrecest/tests/distributions/test_complex_bingham_distribution.py new file mode 100644 index 000000000..21dcc2bc3 --- /dev/null +++ b/pyrecest/tests/distributions/test_complex_bingham_distribution.py @@ -0,0 +1,142 @@ +import unittest + +import numpy as np +import numpy.testing as npt + +from pyrecest.distributions import ComplexBinghamDistribution + + +class TestComplexBinghamDistribution(unittest.TestCase): + """Tests for ComplexBinghamDistribution.""" + + def setUp(self): + # Simple 2x2 diagonal Hermitian B + self.B2 = np.diag([-3.0, 0.0]).astype(complex) + self.cB2 = ComplexBinghamDistribution(self.B2) + + # 3x3 diagonal Hermitian B + self.B3 = np.diag([-5.0, -2.0, 0.0]).astype(complex) + self.cB3 = ComplexBinghamDistribution(self.B3) + + def test_constructor_hermitian_check(self): + """Non-Hermitian matrix should raise AssertionError.""" + with self.assertRaises(AssertionError): + ComplexBinghamDistribution(np.array([[1.0, 1j], [0.0, 1.0]])) + + def test_log_norm_const_finite(self): + """log_norm_const must be finite.""" + self.assertTrue(np.isfinite(self.cB2.log_norm_const)) + self.assertTrue(np.isfinite(self.cB3.log_norm_const)) + + def test_dim(self): + self.assertEqual(self.cB2.dim, 2) + self.assertEqual(self.cB3.dim, 3) + + def test_pdf_normalises_to_one_2d(self): + """MC check: 2-D pdf integrates to 1 over S^3.""" + rng = np.random.default_rng(12345) + raw = rng.standard_normal((2, 200_000)) + 1j * rng.standard_normal((2, 200_000)) + Z = raw / np.linalg.norm(raw, axis=0, keepdims=True) + area = 2.0 * np.pi**2 # surface area of S^3 + mc_integral = np.mean(self.cB2.pdf(Z)) * area + npt.assert_almost_equal(mc_integral, 1.0, decimal=2) + + def test_pdf_positive(self): + """pdf must return positive values.""" + z = np.array([1.0, 0.0], dtype=complex) + self.assertGreater(self.cB2.pdf(z), 0.0) + + def test_pdf_batch_vs_single(self): + """Vectorised pdf matches point-by-point evaluation.""" + rng = np.random.default_rng(0) + raw = rng.standard_normal((2, 10)) + 1j * rng.standard_normal((2, 10)) + Z = raw / np.linalg.norm(raw, axis=0, keepdims=True) + batch = self.cB2.pdf(Z) + single = np.array([self.cB2.pdf(Z[:, k]) for k in range(10)]) + npt.assert_allclose(batch, single, rtol=1e-10) + + def test_pdf_invariant_to_phase(self): + """pdf(exp(i*alpha)*z) == pdf(z) for any global phase alpha.""" + z = np.array([0.6 + 0.3j, 0.0], dtype=complex) + z[1] = np.sqrt(1 - np.abs(z[0]) ** 2) + p0 = self.cB2.pdf(z) + for alpha in [0.3, 1.0, np.pi]: + p1 = self.cB2.pdf(np.exp(1j * alpha) * z) + npt.assert_almost_equal(p1, p0, decimal=10) + + def test_sample_shape(self): + """sample returns the right shape.""" + np.random.seed(42) + S = self.cB2.sample(50) + self.assertEqual(S.shape, (2, 50)) + + def test_sample_unit_norm(self): + """All samples must lie on the unit sphere.""" + np.random.seed(42) + S = self.cB2.sample(100) + norms = np.linalg.norm(S, axis=0) + npt.assert_allclose(norms, np.ones(100), atol=1e-12) + + def test_sample_3d_unit_norm(self): + """3-D samples also lie on the unit sphere.""" + np.random.seed(7) + S = self.cB3.sample(50) + norms = np.linalg.norm(S, axis=0) + npt.assert_allclose(norms, np.ones(50), atol=1e-12) + + def test_log_norm_2d_analytic(self): + a = 3.0 + B = np.diag([-a, 0.0]).astype(complex) + log_C_expected = np.log(2 * np.pi**2 / a * (1 - np.exp(-a))) + log_norm_got = ComplexBinghamDistribution.log_norm(B) + npt.assert_almost_equal(-log_norm_got, log_C_expected, decimal=6) + + def test_log_norm_equal_eigenvalues(self): + """Equal eigenvalues (uniform) should not raise.""" + B = np.zeros((3, 3), dtype=complex) + log_norm = ComplexBinghamDistribution.log_norm(B) + self.assertTrue(np.isfinite(log_norm)) + + def test_fit_returns_instance(self): + """fit() returns a ComplexBinghamDistribution instance.""" + np.random.seed(0) + Z = self.cB2.sample(200) + cB_fit = ComplexBinghamDistribution.fit(Z) + self.assertIsInstance(cB_fit, ComplexBinghamDistribution) + + def test_fit_recovers_eigenvalue_gap(self): + """Fit recovers the eigenvalue gap (up to additive const).""" + np.random.seed(0) + B = np.diag([-10.0, 0.0]).astype(complex) + cB = ComplexBinghamDistribution(B) + Z = cB.sample(2000) + cB_fit = ComplexBinghamDistribution.fit(Z) + evals_fit = np.sort(np.real(np.linalg.eigvalsh(cB_fit.B))) + evals_true = np.sort(np.real(np.linalg.eigvalsh(B))) + gap_true = evals_true[-1] - evals_true[0] + gap_fit = evals_fit[-1] - evals_fit[0] + npt.assert_almost_equal(gap_fit, gap_true, decimal=0) + + def test_cauchy_schwarz_zero_for_identical(self): + """D_CS(p, p) should be 0.""" + d = ComplexBinghamDistribution.cauchy_schwarz_divergence(self.cB2, self.cB2) + npt.assert_almost_equal(d, 0.0, decimal=6) + + def test_cauchy_schwarz_symmetric(self): + """D_CS(p, q) == D_CS(q, p).""" + B3a = np.diag([-5.0, -1.0, 0.0]).astype(complex) + B3b = np.diag([-3.0, -2.0, 0.0]).astype(complex) + d_ab = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3a, B3b) + d_ba = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3b, B3a) + npt.assert_almost_equal(d_ab, d_ba, decimal=6) + + def test_cauchy_schwarz_nonneg(self): + """Cauchy-Schwarz divergence must be >= 0.""" + B3a = np.diag([-5.0, -1.0, 0.0]).astype(complex) + B3b = np.diag([-3.0, -2.0, 0.0]).astype(complex) + d = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3a, B3b) + self.assertGreaterEqual(d, -1e-10) + + +if __name__ == "__main__": + unittest.main() From dd66114e256bb2e2717a4e5f3f9a08aea4038b32 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Apr 2026 05:51:20 +0000 Subject: [PATCH 2/4] Address code review feedback: rename variables, optimize np.delete loop Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/f1b48e18-0fb4-4c4f-a65b-0fa91ff8cb1c Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com> --- .../complex_bingham_distribution.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py index edd2851a1..666030939 100644 --- a/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py @@ -131,8 +131,8 @@ def sample(self, n): # Random phases theta = 2.0 * np.pi * np.random.uniform(size=d) - W = np.sqrt(S) * np.exp(1j * theta) - samples[:, k] = V @ W + weighted_phases = np.sqrt(S) * np.exp(1j * theta) + samples[:, k] = V @ weighted_phases return samples @@ -202,10 +202,13 @@ def _log_norm_from_eigenvalues(eigenvalues): # For each j compute sign_j * exp(log_term_j) where # log_term_j = λ_j - Σ_{k≠j} log|λ_j - λ_k| + # diff_matrix[j, k] = eigenvalues[j] - eigenvalues[k] for j != k + diff_matrix = eigenvalues[:, np.newaxis] - eigenvalues[np.newaxis, :] log_terms = np.empty(d) signs = np.empty(d) for j in range(d): - diffs = eigenvalues[j] - np.delete(eigenvalues, j) + mask = np.arange(d) != j + diffs = diff_matrix[j, mask] signs[j] = np.prod(np.sign(diffs)) log_terms[j] = eigenvalues[j] - np.sum(np.log(np.abs(diffs))) @@ -272,7 +275,7 @@ def _estimate_parameter_matrix(S): d = S.shape[0] eigenvalues_S, V = eigh(S) # ascending - def grad_log_C(lam): + def grad_log_c(lam): """Numerical gradient of log C via forward finite differences.""" eps = 1e-6 B_diag = np.diag(lam.astype(complex)) @@ -289,15 +292,15 @@ def grad_log_C(lam): return grad # Initial guess: spread eigenvalues below zero - x0 = np.linspace(-(d - 1) * 10, -10, d - 1) + initial_eigenvalues = np.linspace(-(d - 1) * 10, -10, d - 1) def residuals(x): lam = np.append(x, 0.0) - return grad_log_C(lam) - eigenvalues_S + return grad_log_c(lam) - eigenvalues_S result = least_squares( residuals, - x0, + initial_eigenvalues, method="lm", ftol=1e-15, xtol=1e-10, From 5d985efc74cf84de98161d6920acb49cffb537e1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Apr 2026 12:08:17 +0000 Subject: [PATCH 3/4] Replace numpy calls with pyrecest.backend in ComplexBinghamDistribution and its tests Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/31c921c4-c991-44f5-89c7-ee929e5d4d07 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com> --- .../complex_bingham_distribution.py | 201 +++++++++++------- .../test_complex_bingham_distribution.py | 138 ++++++++---- 2 files changed, 216 insertions(+), 123 deletions(-) diff --git a/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py index 666030939..ca3c2e998 100644 --- a/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py @@ -1,4 +1,4 @@ -# pylint: disable=no-name-in-module +# pylint: disable=no-name-in-module,no-member,redefined-builtin """Complex Bingham Distribution. Ported from the MATLAB libDirectional library: @@ -8,29 +8,61 @@ Kent, J. T. "The Complex Bingham Distribution and Shape Analysis." Journal of the Royal Statistical Society. Series B (Methodological), 1994, 285-299. """ -import numpy as np -from scipy.linalg import eigh from scipy.optimize import least_squares +from pyrecest.backend import ( + abs, + all, + allclose, + arange, + argsort, + array, + asarray, + complex128, + concatenate, + conj, + cumsum, + diff, + diag, + einsum, + empty, + exp, + linalg, + linspace, + log, + max, + maximum, + minimum, + pi, + prod, + random, + real, + sign, + sort, + sqrt, + sum, + zeros, +) + class ComplexBinghamDistribution: """Complex Bingham distribution on the complex unit hypersphere. The distribution is defined on the complex unit sphere - S^{2d-1} = {z ∈ C^d : ‖z‖ = 1} with pdf + S^{2d-1} = {z in C^d : ||z|| = 1} with pdf - p(z) ∝ exp(z^H B z), + p(z) proportional to exp(z^H B z), - where B is a d×d Hermitian parameter matrix. + where B is a d x d Hermitian parameter matrix. Attributes ---------- - B : numpy.ndarray, shape (d, d), dtype complex128 + B : array, shape (d, d), dtype complex128 Hermitian parameter matrix. dim : int Complex dimension d. log_norm_const : float - Negative log normalization constant (i.e. −log C(B) where C(B) is + Negative log normalization constant (i.e. -log C(B) where C(B) is the normalising constant), stored so that pdf(z) = exp(log_norm_const + Re(z^H B z)). """ @@ -43,8 +75,8 @@ def __init__(self, B): B : array_like, shape (d, d) Hermitian parameter matrix (B == B^H must hold). """ - B = np.asarray(B, dtype=complex) - assert np.allclose(B, B.conj().T, atol=1e-10), "B must be Hermitian" + B = asarray(B, dtype=complex128) + assert allclose(B, conj(B).T, atol=1e-10), "B must be Hermitian" self.B = B self.dim = B.shape[0] self.log_norm_const = ComplexBinghamDistribution.log_norm(B) @@ -64,17 +96,17 @@ def pdf(self, z): Returns ------- - numpy.ndarray, shape (n,) or float + array, shape (n,) or float Pdf value(s). """ - z = np.asarray(z, dtype=complex) + z = asarray(z, dtype=complex128) single = z.ndim == 1 if single: - z = z[:, np.newaxis] + z = z[:, None] # Re(z^H B z) for each column Bz = self.B @ z # (d, n) - vals = np.real(np.einsum("ij,ij->j", z.conj(), Bz)) # shape (n,) - p = np.exp(self.log_norm_const + vals) + vals = real(einsum("ij,ij->j", conj(z), Bz)) # shape (n,) + p = exp(self.log_norm_const + vals) return float(p[0]) if single else p def sample(self, n): @@ -82,7 +114,7 @@ def sample(self, n): Uses the rejection-sampling algorithm of Kent, Constable & Er (2004) "Simulation for the complex Bingham - distribution", *Statistics and Computing*, 14, 53-57. + distribution", Statistics and Computing, 14, 53-57. Parameters ---------- @@ -91,47 +123,47 @@ def sample(self, n): Returns ------- - numpy.ndarray, shape (d, n), dtype complex128 + array, shape (d, n), dtype complex128 Sampled unit vectors (each column has unit norm). """ d = self.dim if d < 2: raise ValueError("Sampling requires d >= 2.") - # Eigendecomposition of -B (eigenvectors V, eigenvalues Λ of -B) - eigenvalues_neg, V = eigh(-self.B) # sorted ascending by scipy + # Eigendecomposition of -B (eigenvectors V, eigenvalues Lambda of -B) + eigenvalues_neg, V = linalg.eigh(-self.B) # sorted ascending # Sort descending - idx = np.argsort(eigenvalues_neg)[::-1] + idx = argsort(eigenvalues_neg)[::-1] eigenvalues_neg = eigenvalues_neg[idx] V = V[:, idx] # Shift so the last (smallest) eigenvalue of -B becomes 0 - Lambda = eigenvalues_neg - eigenvalues_neg[-1] # Λ ≥ 0, Λ[-1] = 0 + Lambda = eigenvalues_neg - eigenvalues_neg[-1] # Lambda >= 0, Lambda[-1] = 0 # Precompute for truncated-exponential CDF inversion Lam = Lambda[:-1] # shape (d-1,) - samples = np.zeros((d, n), dtype=complex) + samples = zeros((d, n), dtype=complex128) for k in range(n): # Rejection loop while True: - S = np.zeros(d) - U = np.random.uniform(size=d - 1) + S = zeros(d) + U = random.uniform(size=(int(d - 1),)) for i in range(d - 1): if Lambda[i] < 0.03: # Nearly-uniform truncated exponential S[i] = U[i] else: - S[i] = -(1.0 / Lam[i]) * np.log( - 1.0 - U[i] * (1.0 - np.exp(-Lam[i])) + S[i] = -(1.0 / Lam[i]) * log( + 1.0 - U[i] * (1.0 - exp(-Lam[i])) ) - if S[:-1].sum() < 1.0: + if sum(S[:-1]) < 1.0: break - S[-1] = 1.0 - S[:-1].sum() + S[-1] = 1.0 - sum(S[:-1]) # Random phases - theta = 2.0 * np.pi * np.random.uniform(size=d) - weighted_phases = np.sqrt(S) * np.exp(1j * theta) + theta = 2.0 * pi * random.uniform(size=(int(d),)) + weighted_phases = sqrt(S) * exp(1j * theta) samples[:, k] = V @ weighted_phases return samples @@ -142,11 +174,11 @@ def sample(self, n): @staticmethod def log_norm(B): - """Compute the *negative* log normalization constant −log C(B). + """Compute the *negative* log normalization constant -log C(B). The formula used is (Kent 1994, eq. 2.3) after eigenvalue shift: - C(λ) = 2π^d · Σ_j exp(λ_j) / Π_{k≠j}(λ_j − λ_k) + C(lambda) = 2*pi^d * sum_j exp(lambda_j) / prod_{k!=j}(lambda_j - lambda_k) For nearly-equal eigenvalues a small perturbation is applied so that the Vandermonde denominators remain well-conditioned, matching the @@ -162,11 +194,11 @@ def log_norm(B): float Negative log normalization constant. """ - B = np.asarray(B, dtype=complex) + B = asarray(B, dtype=complex128) d = B.shape[0] # Real eigenvalues of a Hermitian matrix - eigenvalues = np.linalg.eigvalsh(B) # sorted ascending + eigenvalues = linalg.eigvalsh(B) # sorted ascending # Shift so the maximum eigenvalue is 0 eigenvalue_shift = float(eigenvalues[-1]) @@ -175,18 +207,18 @@ def log_norm(B): # Perturb near-equal eigenvalues for numerical stability eigenvalues = ComplexBinghamDistribution._perturb_eigenvalues(eigenvalues) - if np.all(np.abs(eigenvalues) < 1e-3): - # All eigenvalues near zero: use limiting uniform value C = 2π^d / (d-1)! - log_C_shifted = np.log(2.0) + d * np.log(np.pi) - float( - np.sum(np.log(np.arange(1, d))) + if all(abs(eigenvalues) < 1e-3): + # All eigenvalues near zero: use limiting uniform value C = 2*pi^d / (d-1)! + log_C_shifted = log(2.0) + d * log(pi) - float( + sum(log(arange(1, d))) ) else: - # Analytical formula: C = 2π^d · Σ_j exp(λ_j) / Π_{k≠j}(λ_j - λ_k) + # Analytical formula: C = 2*pi^d * sum_j exp(lambda_j) / prod_{k!=j}(lambda_j - lambda_k) log_C_shifted = ComplexBinghamDistribution._log_norm_from_eigenvalues( eigenvalues ) - # Apply eigenvalue-shift correction: C(λ) = exp(shift) · C(λ - shift) + # Apply eigenvalue-shift correction: C(lambda) = exp(shift) * C(lambda - shift) log_C = log_C_shifted + eigenvalue_shift return -log_C # inverted convention: log_norm_const = -log C @@ -195,27 +227,27 @@ def log_norm(B): def _log_norm_from_eigenvalues(eigenvalues): """Log normalization for shifted eigenvalues (max = 0) via partial fractions. - Computes log(2π^d · Σ_j exp(λ_j) / Π_{k≠j}(λ_j - λ_k)). + Computes log(2*pi^d * sum_j exp(lambda_j) / prod_{k!=j}(lambda_j - lambda_k)). """ d = len(eigenvalues) - log_prefix = np.log(2.0) + d * np.log(np.pi) + log_prefix = log(2.0) + d * log(pi) # For each j compute sign_j * exp(log_term_j) where - # log_term_j = λ_j - Σ_{k≠j} log|λ_j - λ_k| + # log_term_j = lambda_j - sum_{k!=j} log|lambda_j - lambda_k| # diff_matrix[j, k] = eigenvalues[j] - eigenvalues[k] for j != k - diff_matrix = eigenvalues[:, np.newaxis] - eigenvalues[np.newaxis, :] - log_terms = np.empty(d) - signs = np.empty(d) + diff_matrix = eigenvalues[:, None] - eigenvalues[None, :] + log_terms = empty(d) + signs = empty(d) for j in range(d): - mask = np.arange(d) != j + mask = arange(d) != j diffs = diff_matrix[j, mask] - signs[j] = np.prod(np.sign(diffs)) - log_terms[j] = eigenvalues[j] - np.sum(np.log(np.abs(diffs))) + signs[j] = prod(sign(diffs)) + log_terms[j] = eigenvalues[j] - sum(log(abs(diffs))) - # log(|Σ_j sign_j · exp(log_term_j)|) via log-sum-exp - max_log = np.max(log_terms) - scaled = np.sum(signs * np.exp(log_terms - max_log)) - log_sum = max_log + np.log(np.abs(scaled)) + # log(|sum_j sign_j * exp(log_term_j)|) via log-sum-exp + max_log = max(log_terms) + scaled = sum(signs * exp(log_terms - max_log)) + log_sum = max_log + log(abs(scaled)) return log_prefix + log_sum @@ -223,12 +255,12 @@ def _log_norm_from_eigenvalues(eigenvalues): def _perturb_eigenvalues(eigenvalues): """Sort eigenvalues descending and enforce minimum spacing of 0.01. - Mirrors MATLAB's ``makeSureEigenvaluesAreNotTooClose``. + Mirrors MATLAB's makeSureEigenvaluesAreNotTooClose. """ - lam = np.sort(eigenvalues)[::-1].copy() - diffs = np.diff(lam) # non-positive for sorted-descending - diffs = np.minimum(diffs, -0.01) # enforce gap ≥ 0.01 - lam[1:] = lam[0] + np.cumsum(diffs) + lam = sort(eigenvalues)[::-1] + diffs = diff(lam) # non-positive for sorted-descending + diffs = minimum(diffs, -0.01) # enforce gap >= 0.01 + lam[1:] = lam[0] + cumsum(diffs) return lam @classmethod @@ -238,15 +270,15 @@ def fit(cls, Z): Parameters ---------- Z : array_like, shape (d, n) - Complex unit vectors (columns); it is assumed that ‖z_i‖ = 1. + Complex unit vectors (columns); it is assumed that ||z_i|| = 1. Returns ------- ComplexBinghamDistribution """ - Z = np.asarray(Z, dtype=complex) + Z = asarray(Z, dtype=complex128) n = Z.shape[1] - S = Z @ Z.conj().T / n # sample scatter matrix + S = Z @ conj(Z).T / n # sample scatter matrix B = cls._estimate_parameter_matrix(S) return cls(B) @@ -257,45 +289,45 @@ def _estimate_parameter_matrix(S): The eigenvectors of B equal those of S. The eigenvalues of B are found by solving the moment-matching equations - ∂ log C(λ) / ∂λ_k = s_k, k = 1, …, d + d log C(lambda) / d lambda_k = s_k, k = 1, ..., d where s_k are the eigenvalues of S, using a least-squares solver with finite-difference gradients. Parameters ---------- - S : numpy.ndarray, shape (d, d), complex Hermitian + S : array, shape (d, d), complex Hermitian Sample scatter matrix (E[z z^H] estimate). Returns ------- - numpy.ndarray, shape (d, d), complex Hermitian + array, shape (d, d), complex Hermitian Estimated parameter matrix B. """ d = S.shape[0] - eigenvalues_S, V = eigh(S) # ascending + eigenvalues_S, V = linalg.eigh(S) # ascending def grad_log_c(lam): """Numerical gradient of log C via forward finite differences.""" eps = 1e-6 - B_diag = np.diag(lam.astype(complex)) + B_diag = diag(array(lam, dtype=complex128)) log_c0 = ComplexBinghamDistribution.log_norm(B_diag) - grad = np.empty(d) + grad = empty(d) for i in range(d): - lam_p = lam.copy() + lam_p = array(lam) lam_p[i] += eps log_cp = ComplexBinghamDistribution.log_norm( - np.diag(lam_p.astype(complex)) + diag(array(lam_p, dtype=complex128)) ) # log_norm_const = -log C, so d(log C)/dλ_i = -d(log_norm_const)/dλ_i grad[i] = (-log_cp - (-log_c0)) / eps return grad # Initial guess: spread eigenvalues below zero - initial_eigenvalues = np.linspace(-(d - 1) * 10, -10, d - 1) + initial_eigenvalues = linspace(-(d - 1) * 10, -10, int(d - 1)) def residuals(x): - lam = np.append(x, 0.0) + lam = concatenate([x, array([0.0])]) return grad_log_c(lam) - eigenvalues_S result = least_squares( @@ -306,37 +338,44 @@ def residuals(x): xtol=1e-10, max_nfev=int(1e4), ) - lam_B = np.append(result.x, 0.0) - B = V @ np.diag(lam_B.astype(complex)) @ V.conj().T - B = 0.5 * (B + B.conj().T) # enforce exact Hermitian symmetry + lam_B = concatenate([result.x, array([0.0])]) + B = V @ diag(array(lam_B, dtype=complex128)) @ conj(V).T + B = 0.5 * (B + conj(B).T) # enforce exact Hermitian symmetry return B @staticmethod def cauchy_schwarz_divergence(cB1, cB2): """Cauchy-Schwarz divergence between two complex Bingham distributions. - D_CS(p, q) = log C(B₁+B₂) − ½[log C(2B₁) + log C(2B₂)] + D_CS(p, q) = half * [log C(2*B1) + log C(2*B2)] - log C(B1+B2) >= 0 + + Using the stored negative log normalization (log_norm = -log C): + + D_CS = log_norm(B1+B2) - half * [log_norm(2*B1) + log_norm(2*B2)] + + This matches the MATLAB libDirectional CauchySchwarzDivergence convention. Parameters ---------- - cB1, cB2 : ComplexBinghamDistribution or numpy.ndarray - Distributions or parameter matrices. + cB1, cB2 : ComplexBinghamDistribution or array_like + Distributions or Hermitian parameter matrices. Returns ------- float + Non-negative divergence value. """ if isinstance(cB1, ComplexBinghamDistribution): B1 = cB1.B else: - B1 = np.asarray(cB1, dtype=complex) + B1 = asarray(cB1, dtype=complex128) if isinstance(cB2, ComplexBinghamDistribution): B2 = cB2.B else: - B2 = np.asarray(cB2, dtype=complex) + B2 = asarray(cB2, dtype=complex128) - assert np.allclose(B1, B1.conj().T, atol=1e-10), "B1 must be Hermitian" - assert np.allclose(B2, B2.conj().T, atol=1e-10), "B2 must be Hermitian" + assert allclose(B1, conj(B1).T, atol=1e-10), "B1 must be Hermitian" + assert allclose(B2, conj(B2).T, atol=1e-10), "B2 must be Hermitian" log_c1 = ComplexBinghamDistribution.log_norm(2.0 * B1) log_c2 = ComplexBinghamDistribution.log_norm(2.0 * B2) diff --git a/pyrecest/tests/distributions/test_complex_bingham_distribution.py b/pyrecest/tests/distributions/test_complex_bingham_distribution.py index 21dcc2bc3..dcb6cbf74 100644 --- a/pyrecest/tests/distributions/test_complex_bingham_distribution.py +++ b/pyrecest/tests/distributions/test_complex_bingham_distribution.py @@ -1,7 +1,22 @@ import unittest -import numpy as np import numpy.testing as npt +import pyrecest.backend + +# pylint: disable=no-name-in-module,no-member +from pyrecest.backend import ( + array, + diag, + exp, + linalg, + log, + mean, + pi, + random, + real, + sort, + sqrt, +) from pyrecest.distributions import ComplexBinghamDistribution @@ -11,110 +26,149 @@ class TestComplexBinghamDistribution(unittest.TestCase): def setUp(self): # Simple 2x2 diagonal Hermitian B - self.B2 = np.diag([-3.0, 0.0]).astype(complex) + self.B2 = diag(array([-3.0, 0.0], dtype=complex)) self.cB2 = ComplexBinghamDistribution(self.B2) # 3x3 diagonal Hermitian B - self.B3 = np.diag([-5.0, -2.0, 0.0]).astype(complex) + self.B3 = diag(array([-5.0, -2.0, 0.0], dtype=complex)) self.cB3 = ComplexBinghamDistribution(self.B3) def test_constructor_hermitian_check(self): """Non-Hermitian matrix should raise AssertionError.""" with self.assertRaises(AssertionError): - ComplexBinghamDistribution(np.array([[1.0, 1j], [0.0, 1.0]])) + ComplexBinghamDistribution( + array([[1.0, 1j], [0.0, 1.0]]) + ) def test_log_norm_const_finite(self): """log_norm_const must be finite.""" - self.assertTrue(np.isfinite(self.cB2.log_norm_const)) - self.assertTrue(np.isfinite(self.cB3.log_norm_const)) + import math + self.assertTrue(math.isfinite(self.cB2.log_norm_const)) + self.assertTrue(math.isfinite(self.cB3.log_norm_const)) def test_dim(self): self.assertEqual(self.cB2.dim, 2) self.assertEqual(self.cB3.dim, 3) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) def test_pdf_normalises_to_one_2d(self): """MC check: 2-D pdf integrates to 1 over S^3.""" - rng = np.random.default_rng(12345) - raw = rng.standard_normal((2, 200_000)) + 1j * rng.standard_normal((2, 200_000)) - Z = raw / np.linalg.norm(raw, axis=0, keepdims=True) - area = 2.0 * np.pi**2 # surface area of S^3 - mc_integral = np.mean(self.cB2.pdf(Z)) * area + random.seed(12345) + # Sample uniformly from S^3 using complex Gaussian projection + real_part = random.normal(size=(2, 200_000)) + imag_part = random.normal(size=(2, 200_000)) + raw = real_part + 1j * imag_part + norms = linalg.norm(raw, axis=0, keepdims=True) + Z = raw / norms + area = 2.0 * float(pi) ** 2 # surface area of S^3 + mc_integral = float(mean(self.cB2.pdf(Z))) * area npt.assert_almost_equal(mc_integral, 1.0, decimal=2) def test_pdf_positive(self): """pdf must return positive values.""" - z = np.array([1.0, 0.0], dtype=complex) + z = array([1.0, 0.0], dtype=complex) self.assertGreater(self.cB2.pdf(z), 0.0) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) def test_pdf_batch_vs_single(self): """Vectorised pdf matches point-by-point evaluation.""" - rng = np.random.default_rng(0) - raw = rng.standard_normal((2, 10)) + 1j * rng.standard_normal((2, 10)) - Z = raw / np.linalg.norm(raw, axis=0, keepdims=True) + random.seed(0) + real_part = random.normal(size=(2, 10)) + imag_part = random.normal(size=(2, 10)) + raw = real_part + 1j * imag_part + Z = raw / linalg.norm(raw, axis=0, keepdims=True) batch = self.cB2.pdf(Z) - single = np.array([self.cB2.pdf(Z[:, k]) for k in range(10)]) + single = array([self.cB2.pdf(Z[:, k]) for k in range(10)]) npt.assert_allclose(batch, single, rtol=1e-10) def test_pdf_invariant_to_phase(self): """pdf(exp(i*alpha)*z) == pdf(z) for any global phase alpha.""" - z = np.array([0.6 + 0.3j, 0.0], dtype=complex) - z[1] = np.sqrt(1 - np.abs(z[0]) ** 2) + z = array([0.6 + 0.3j, 0.0], dtype=complex) + z[1] = sqrt(1 - abs(z[0]) ** 2) p0 = self.cB2.pdf(z) - for alpha in [0.3, 1.0, np.pi]: - p1 = self.cB2.pdf(np.exp(1j * alpha) * z) + for alpha in [0.3, 1.0, float(pi)]: + p1 = self.cB2.pdf(exp(1j * alpha) * z) npt.assert_almost_equal(p1, p0, decimal=10) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) def test_sample_shape(self): """sample returns the right shape.""" - np.random.seed(42) + random.seed(42) S = self.cB2.sample(50) self.assertEqual(S.shape, (2, 50)) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) def test_sample_unit_norm(self): """All samples must lie on the unit sphere.""" - np.random.seed(42) + random.seed(42) S = self.cB2.sample(100) - norms = np.linalg.norm(S, axis=0) - npt.assert_allclose(norms, np.ones(100), atol=1e-12) + norms = linalg.norm(S, axis=0) + npt.assert_allclose(norms, [1.0] * 100, atol=1e-12) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) def test_sample_3d_unit_norm(self): """3-D samples also lie on the unit sphere.""" - np.random.seed(7) + random.seed(7) S = self.cB3.sample(50) - norms = np.linalg.norm(S, axis=0) - npt.assert_allclose(norms, np.ones(50), atol=1e-12) + norms = linalg.norm(S, axis=0) + npt.assert_allclose(norms, [1.0] * 50, atol=1e-12) def test_log_norm_2d_analytic(self): a = 3.0 - B = np.diag([-a, 0.0]).astype(complex) - log_C_expected = np.log(2 * np.pi**2 / a * (1 - np.exp(-a))) + B = diag(array([-a, 0.0], dtype=complex)) + log_C_expected = float(log(2 * pi**2 / a * (1 - exp(-a)))) log_norm_got = ComplexBinghamDistribution.log_norm(B) npt.assert_almost_equal(-log_norm_got, log_C_expected, decimal=6) def test_log_norm_equal_eigenvalues(self): """Equal eigenvalues (uniform) should not raise.""" - B = np.zeros((3, 3), dtype=complex) + from pyrecest.backend import zeros + B = zeros((3, 3), dtype=complex) log_norm = ComplexBinghamDistribution.log_norm(B) - self.assertTrue(np.isfinite(log_norm)) + import math + self.assertTrue(math.isfinite(log_norm)) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) def test_fit_returns_instance(self): """fit() returns a ComplexBinghamDistribution instance.""" - np.random.seed(0) + random.seed(0) Z = self.cB2.sample(200) cB_fit = ComplexBinghamDistribution.fit(Z) self.assertIsInstance(cB_fit, ComplexBinghamDistribution) + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on JAX backend", + ) def test_fit_recovers_eigenvalue_gap(self): """Fit recovers the eigenvalue gap (up to additive const).""" - np.random.seed(0) - B = np.diag([-10.0, 0.0]).astype(complex) + random.seed(0) + B = diag(array([-10.0, 0.0], dtype=complex)) cB = ComplexBinghamDistribution(B) Z = cB.sample(2000) cB_fit = ComplexBinghamDistribution.fit(Z) - evals_fit = np.sort(np.real(np.linalg.eigvalsh(cB_fit.B))) - evals_true = np.sort(np.real(np.linalg.eigvalsh(B))) - gap_true = evals_true[-1] - evals_true[0] - gap_fit = evals_fit[-1] - evals_fit[0] + evals_fit = sort(real(linalg.eigvalsh(cB_fit.B))) + evals_true = sort(real(linalg.eigvalsh(B))) + gap_true = float(evals_true[-1] - evals_true[0]) + gap_fit = float(evals_fit[-1] - evals_fit[0]) npt.assert_almost_equal(gap_fit, gap_true, decimal=0) def test_cauchy_schwarz_zero_for_identical(self): @@ -124,16 +178,16 @@ def test_cauchy_schwarz_zero_for_identical(self): def test_cauchy_schwarz_symmetric(self): """D_CS(p, q) == D_CS(q, p).""" - B3a = np.diag([-5.0, -1.0, 0.0]).astype(complex) - B3b = np.diag([-3.0, -2.0, 0.0]).astype(complex) + B3a = diag(array([-5.0, -1.0, 0.0], dtype=complex)) + B3b = diag(array([-3.0, -2.0, 0.0], dtype=complex)) d_ab = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3a, B3b) d_ba = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3b, B3a) npt.assert_almost_equal(d_ab, d_ba, decimal=6) def test_cauchy_schwarz_nonneg(self): """Cauchy-Schwarz divergence must be >= 0.""" - B3a = np.diag([-5.0, -1.0, 0.0]).astype(complex) - B3b = np.diag([-3.0, -2.0, 0.0]).astype(complex) + B3a = diag(array([-5.0, -1.0, 0.0], dtype=complex)) + B3b = diag(array([-3.0, -2.0, 0.0], dtype=complex)) d = ComplexBinghamDistribution.cauchy_schwarz_divergence(B3a, B3b) self.assertGreaterEqual(d, -1e-10) From a5f140d7f62238bbe4b9249c48068d8d4e671bf2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 6 Apr 2026 12:10:05 +0000 Subject: [PATCH 4/4] Address code review: use .copy() for sort output, use ones() in tests Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/31c921c4-c991-44f5-89c7-ee929e5d4d07 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com> --- .../hypersphere_subset/complex_bingham_distribution.py | 4 ++-- .../tests/distributions/test_complex_bingham_distribution.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py index ca3c2e998..7df549cf0 100644 --- a/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/complex_bingham_distribution.py @@ -257,7 +257,7 @@ def _perturb_eigenvalues(eigenvalues): Mirrors MATLAB's makeSureEigenvaluesAreNotTooClose. """ - lam = sort(eigenvalues)[::-1] + lam = sort(eigenvalues)[::-1].copy() diffs = diff(lam) # non-positive for sorted-descending diffs = minimum(diffs, -0.01) # enforce gap >= 0.01 lam[1:] = lam[0] + cumsum(diffs) @@ -314,7 +314,7 @@ def grad_log_c(lam): log_c0 = ComplexBinghamDistribution.log_norm(B_diag) grad = empty(d) for i in range(d): - lam_p = array(lam) + lam_p = lam.copy() lam_p[i] += eps log_cp = ComplexBinghamDistribution.log_norm( diag(array(lam_p, dtype=complex128)) diff --git a/pyrecest/tests/distributions/test_complex_bingham_distribution.py b/pyrecest/tests/distributions/test_complex_bingham_distribution.py index dcb6cbf74..096507b59 100644 --- a/pyrecest/tests/distributions/test_complex_bingham_distribution.py +++ b/pyrecest/tests/distributions/test_complex_bingham_distribution.py @@ -12,6 +12,7 @@ log, mean, pi, + ones, random, real, sort, @@ -115,7 +116,7 @@ def test_sample_unit_norm(self): random.seed(42) S = self.cB2.sample(100) norms = linalg.norm(S, axis=0) - npt.assert_allclose(norms, [1.0] * 100, atol=1e-12) + npt.assert_allclose(norms, ones(100), atol=1e-12) @unittest.skipIf( pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member @@ -126,7 +127,7 @@ def test_sample_3d_unit_norm(self): random.seed(7) S = self.cB3.sample(50) norms = linalg.norm(S, axis=0) - npt.assert_allclose(norms, [1.0] * 50, atol=1e-12) + npt.assert_allclose(norms, ones(50), atol=1e-12) def test_log_norm_2d_analytic(self): a = 3.0