diff --git a/pyrecest/distributions/__init__.py b/pyrecest/distributions/__init__.py index 818450ca5..7f457bb81 100644 --- a/pyrecest/distributions/__init__.py +++ b/pyrecest/distributions/__init__.py @@ -154,6 +154,9 @@ AbstractSphericalHarmonicsDistribution, ) from .hypersphere_subset.bingham_distribution import BinghamDistribution +from .hypersphere_subset.complex_angular_central_gaussian_distribution import ( + ComplexAngularCentralGaussianDistribution, +) from .hypersphere_subset.custom_hemispherical_distribution import ( CustomHemisphericalDistribution, ) @@ -341,6 +344,7 @@ "AbstractSphericalDistribution", "AbstractSphericalHarmonicsDistribution", "BinghamDistribution", + "ComplexAngularCentralGaussianDistribution", "CustomHemisphericalDistribution", "CustomHyperhemisphericalDistribution", "CustomHypersphericalDistribution", diff --git a/pyrecest/distributions/hypersphere_subset/complex_angular_central_gaussian_distribution.py b/pyrecest/distributions/hypersphere_subset/complex_angular_central_gaussian_distribution.py new file mode 100644 index 000000000..b0bec33c6 --- /dev/null +++ b/pyrecest/distributions/hypersphere_subset/complex_angular_central_gaussian_distribution.py @@ -0,0 +1,149 @@ +# pylint: disable=redefined-builtin,no-name-in-module,no-member +import numpy as np + +from pyrecest.backend import ( + abs, + allclose, + array, + conj, + exp, + gammaln, + linalg, + log, + pi, + random, + real, + sqrt, + sum, + transpose, +) + + +class ComplexAngularCentralGaussianDistribution: + """Complex Angular Central Gaussian distribution on the complex unit hypersphere. + + This distribution is defined on the complex unit sphere in C^d (equivalently S^{2d-1} + in R^{2d}). It is parameterized by a Hermitian positive definite matrix C of shape (d, d). + + Reference: + Ported from ComplexAngularCentralGaussian.m in libDirectional: + https://github.com/libDirectional/libDirectional/blob/master/lib/distributions/complexHypersphere/ComplexAngularCentralGaussian.m + """ + + def __init__(self, C): + """Initialize the distribution. + + Parameters + ---------- + C : array-like of shape (d, d) + Hermitian positive definite parameter matrix. + """ + assert allclose(C, conj(transpose(C))), "C must be Hermitian" + self.C = C + self.dim = C.shape[0] + + def pdf(self, za): + """Evaluate the pdf at each row of za. + + Parameters + ---------- + za : array-like of shape (n, d) or (d,) + Points on the complex unit sphere. Each row is a complex unit vector. + + Returns + ------- + p : array-like of shape (n,) or scalar + PDF values at each row of za. + """ + single = za.ndim == 1 + if single: + za = za.reshape(1, -1) + + # Solve C * X = za.T to get C^{-1} * za.T, shape (d, n) + C_inv_z = linalg.solve(self.C, transpose(za)) + # Hermitian quadratic form: inner[i] = za[i]^H C^{-1} za[i] + inner = sum(conj(transpose(za)) * C_inv_z, axis=0) # shape (n,) + + d = self.dim + # gamma(d) / (2 * pi^d) in log space: gammaln(d) - log(2) - d*log(pi) + log_normalizer = gammaln(array(float(d))) - log(2.0) - d * log(array(pi)) + p = exp(log_normalizer) * abs(inner) ** (-d) / abs(linalg.det(self.C)) + + if single: + return p[0] + return p + + def sample(self, n): + """Sample n points from the distribution. + + Parameters + ---------- + n : int + Number of samples. + + Returns + ------- + Z : array-like of shape (n, d) + Complex unit vectors sampled from the distribution. + """ + # Lower Cholesky factor: C = L @ L^H + L = linalg.cholesky(self.C) + a = random.normal(size=(n, self.dim)) + b = random.normal(size=(n, self.dim)) + # Each row of (a + 1j*b) is CN(0, 2*I); transform by L^T to get CN(0, 2*C) + # Using regular transpose (not conjugate) so each row maps as z -> L @ z (column form) + z = (a + 1j * b) @ transpose(L) + norms = sqrt(real(sum(z * conj(z), axis=-1))) + return z / norms.reshape(-1, 1) + + @staticmethod + def fit(Z, n_iterations=100): + """Fit distribution to data Z using fixed-point iterations. + + Parameters + ---------- + Z : array-like of shape (n, d) + Complex unit vectors (each row is a sample). + n_iterations : int, optional + Number of fixed-point iterations (default 100). + + Returns + ------- + dist : ComplexAngularCentralGaussianDistribution + """ + C = ComplexAngularCentralGaussianDistribution.estimate_parameter_matrix( + Z, n_iterations + ) + return ComplexAngularCentralGaussianDistribution(C) + + @staticmethod + def estimate_parameter_matrix(Z, n_iterations=100): + """Estimate the parameter matrix from data using fixed-point iterations. + + Parameters + ---------- + Z : array-like of shape (n, d) + Complex unit vectors (each row is a sample). + n_iterations : int, optional + Number of iterations (default 100). + + Returns + ------- + C : array-like of shape (d, d) + Estimated Hermitian parameter matrix. + """ + N = Z.shape[0] + D = Z.shape[1] + C = array(np.eye(D, dtype=complex)) + + for _ in range(n_iterations): + # Solve C * X = Z.T to get C^{-1} * Z.T, shape (d, n) + C_inv_Z = linalg.solve(C, transpose(Z)) + # Hermitian quadratic forms: inner[k] = Z[k]^H C^{-1} Z[k] + inner = sum(conj(transpose(Z)) * C_inv_Z, axis=0) # shape (n,) + weights = (D - 1) / abs(inner) # shape (n,) + # C = (1/N) * sum_k weights[k] * z_k z_k^H + # = Z.T @ diag(weights) @ conj(Z) / N + C = transpose(Z) @ (weights.reshape(-1, 1) * conj(Z)) / N + + return C diff --git a/pyrecest/tests/distributions/test_complex_angular_central_gaussian_distribution.py b/pyrecest/tests/distributions/test_complex_angular_central_gaussian_distribution.py new file mode 100644 index 000000000..772d9a84b --- /dev/null +++ b/pyrecest/tests/distributions/test_complex_angular_central_gaussian_distribution.py @@ -0,0 +1,187 @@ +import unittest + +import numpy as np +import numpy.testing as npt +import pyrecest.backend + +# pylint: disable=no-name-in-module,no-member +from pyrecest.backend import array +from pyrecest.distributions import ComplexAngularCentralGaussianDistribution + + +class TestComplexAngularCentralGaussianDistribution(unittest.TestCase): + def setUp(self): + """Set up test fixtures.""" + # Identity matrix case (uniform distribution on complex unit sphere) + self.C_identity_2d = array(np.eye(2, dtype=complex)) + self.dist_identity_2d = ComplexAngularCentralGaussianDistribution( + self.C_identity_2d + ) + + # Non-trivial Hermitian positive definite matrix for 2D case + # C = [[2, 1+1j], [1-1j, 3]] + C_vals = np.array([[2.0, 1.0 + 1.0j], [1.0 - 1.0j, 3.0]]) + self.C_nontrivial_2d = array(C_vals) + self.dist_nontrivial_2d = ComplexAngularCentralGaussianDistribution( + self.C_nontrivial_2d + ) + + def test_constructor_valid(self): + """Test that constructor accepts a Hermitian matrix.""" + self.assertEqual(self.dist_identity_2d.dim, 2) + self.assertEqual(self.dist_nontrivial_2d.dim, 2) + + def test_constructor_non_hermitian_raises(self): + """Test that constructor rejects a non-Hermitian matrix.""" + C_bad = array(np.array([[1.0, 2.0 + 1.0j], [0.0, 1.0]])) + with self.assertRaises(AssertionError): + ComplexAngularCentralGaussianDistribution(C_bad) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_pdf_identity_uniform(self): + """For C=I, the pdf should be constant gamma(d)/(2*pi^d) on the unit sphere.""" + d = 2 + # Expected: gamma(2) / (2*pi^2) = 1 / (2*pi^2) + expected = 1.0 / (2.0 * np.pi**d) # gamma(2)=1 + + # Test on several unit vectors + z1 = array(np.array([[1.0, 0.0]], dtype=complex)) + z2 = array(np.array([[0.0, 1.0]], dtype=complex)) + z3 = array((np.array([1.0, 1.0j]) / np.sqrt(2.0)).reshape(1, -1)) + z4 = array((np.array([1.0 + 1.0j, 1.0 - 1.0j]) / 2.0).reshape(1, -1)) + + for z in [z1, z2, z3, z4]: + p = self.dist_identity_2d.pdf(z) + npt.assert_allclose( + float(np.real(np.array(p[0]))), + expected, + rtol=1e-6, + err_msg=f"PDF for identity C is not constant at {z}", + ) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_pdf_positive(self): + """PDF values should be positive for any unit vector.""" + z = array(np.array([[1.0 / np.sqrt(2.0), 1.0j / np.sqrt(2.0)]])) + p = self.dist_nontrivial_2d.pdf(z) + self.assertGreater(float(np.real(np.array(p[0]))), 0.0) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_pdf_batch_vs_single(self): + """Batch PDF evaluation should match individual evaluations.""" + zs = np.array( + [ + [1.0, 0.0], + [0.0, 1.0], + [1.0 / np.sqrt(2.0), 1.0j / np.sqrt(2.0)], + ], + dtype=complex, + ) + za = array(zs) + + p_batch = self.dist_nontrivial_2d.pdf(za) + for i, z in enumerate(zs): + p_single = self.dist_nontrivial_2d.pdf(array(z.reshape(1, -1))) + npt.assert_allclose( + float(np.real(np.array(p_batch[i]))), + float(np.real(np.array(p_single[0]))), + rtol=1e-10, + ) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_sample_unit_norm(self): + """Sampled vectors should lie on the complex unit sphere.""" + n = 100 + Z = self.dist_nontrivial_2d.sample(n) + Z_np = np.array(Z) + norms_sq = np.array( + [np.real(np.sum(Z_np[k] * np.conj(Z_np[k]))) for k in range(n)] + ) + npt.assert_allclose(norms_sq, np.ones(n), atol=1e-10) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_sample_correct_dim(self): + """Sampled vectors should have the correct shape.""" + n = 50 + Z = self.dist_identity_2d.sample(n) + self.assertEqual(Z.shape[0], n) + self.assertEqual(Z.shape[1], 2) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_estimate_parameter_matrix_identity(self): + """Fitting samples from identity-C distribution should recover approx identity.""" + pyrecest.backend.random.seed(42) # pylint: disable=no-member + n = 2000 + Z = self.dist_identity_2d.sample(n) + C_est = ComplexAngularCentralGaussianDistribution.estimate_parameter_matrix( + Z, n_iterations=100 + ) + # Normalize C_est to have trace equal to 2 (matching identity) + C_est_np = np.array(C_est) + C_est_normalized = C_est_np / np.trace(C_est_np).real * 2.0 + npt.assert_allclose( + np.real(C_est_normalized), + np.eye(2), + atol=0.15, + err_msg="Estimated C does not approximately match identity", + ) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_fit_returns_distribution(self): + """fit() should return a ComplexAngularCentralGaussianDistribution.""" + pyrecest.backend.random.seed(0) # pylint: disable=no-member + Z = self.dist_identity_2d.sample(50) + dist = ComplexAngularCentralGaussianDistribution.fit(Z, n_iterations=10) + self.assertIsInstance(dist, ComplexAngularCentralGaussianDistribution) + self.assertEqual(dist.dim, 2) + + @unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", + reason="Not supported on JAX backend", + ) # pylint: disable=no-member + def test_3d_case(self): + """Test basic functionality for d=3.""" + C_3d = array(np.eye(3, dtype=complex)) + dist = ComplexAngularCentralGaussianDistribution(C_3d) + self.assertEqual(dist.dim, 3) + + Z = dist.sample(20) + self.assertEqual(Z.shape, (20, 3)) + + # Check unit norms + Z_np = np.array(Z) + norms_sq = np.array( + [np.real(np.sum(Z_np[k] * np.conj(Z_np[k]))) for k in range(20)] + ) + npt.assert_allclose(norms_sq, np.ones(20), atol=1e-10) + + # For d=3, C=I: pdf should be gamma(3)/(2*pi^3) = 2/(2*pi^3) = 1/pi^3 + z_test = array(np.array([[1.0, 0.0, 0.0]], dtype=complex)) + p = dist.pdf(z_test) + expected = 1.0 / np.pi**3 # gamma(3)=2, so 2/(2*pi^3)=1/pi^3 + npt.assert_allclose(float(np.real(np.array(p[0]))), expected, rtol=1e-6) + + +if __name__ == "__main__": + unittest.main()