Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyrecest/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@
AbstractSphericalHarmonicsDistribution,
)
from .hypersphere_subset.bingham_distribution import BinghamDistribution
from .hypersphere_subset.complex_angular_central_gaussian_distribution import (
ComplexAngularCentralGaussianDistribution,
)
from .hypersphere_subset.custom_hemispherical_distribution import (
CustomHemisphericalDistribution,
)
Expand Down Expand Up @@ -341,6 +344,7 @@
"AbstractSphericalDistribution",
"AbstractSphericalHarmonicsDistribution",
"BinghamDistribution",
"ComplexAngularCentralGaussianDistribution",
"CustomHemisphericalDistribution",
"CustomHyperhemisphericalDistribution",
"CustomHypersphericalDistribution",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# pylint: disable=redefined-builtin,no-name-in-module,no-member
import numpy as np

from pyrecest.backend import (
abs,
allclose,
array,
conj,
exp,
gammaln,
linalg,
log,
pi,
random,
real,
sqrt,
sum,
transpose,
)


class ComplexAngularCentralGaussianDistribution:
"""Complex Angular Central Gaussian distribution on the complex unit hypersphere.

This distribution is defined on the complex unit sphere in C^d (equivalently S^{2d-1}
in R^{2d}). It is parameterized by a Hermitian positive definite matrix C of shape (d, d).

Reference:
Ported from ComplexAngularCentralGaussian.m in libDirectional:
https://github.com/libDirectional/libDirectional/blob/master/lib/distributions/complexHypersphere/ComplexAngularCentralGaussian.m
"""

def __init__(self, C):
"""Initialize the distribution.

Parameters
----------
C : array-like of shape (d, d)
Hermitian positive definite parameter matrix.
"""
assert allclose(C, conj(transpose(C))), "C must be Hermitian"
self.C = C
self.dim = C.shape[0]

def pdf(self, za):
"""Evaluate the pdf at each row of za.

Parameters
----------
za : array-like of shape (n, d) or (d,)
Points on the complex unit sphere. Each row is a complex unit vector.

Returns
-------
p : array-like of shape (n,) or scalar
PDF values at each row of za.
"""
single = za.ndim == 1
if single:
za = za.reshape(1, -1)

# Solve C * X = za.T to get C^{-1} * za.T, shape (d, n)
C_inv_z = linalg.solve(self.C, transpose(za))
# Hermitian quadratic form: inner[i] = za[i]^H C^{-1} za[i]
inner = sum(conj(transpose(za)) * C_inv_z, axis=0) # shape (n,)

d = self.dim
# gamma(d) / (2 * pi^d) in log space: gammaln(d) - log(2) - d*log(pi)
log_normalizer = gammaln(array(float(d))) - log(2.0) - d * log(array(pi))
p = exp(log_normalizer) * abs(inner) ** (-d) / abs(linalg.det(self.C))

if single:
return p[0]
return p

def sample(self, n):
"""Sample n points from the distribution.

Parameters
----------
n : int
Number of samples.

Returns
-------
Z : array-like of shape (n, d)
Complex unit vectors sampled from the distribution.
"""
# Lower Cholesky factor: C = L @ L^H
L = linalg.cholesky(self.C)
a = random.normal(size=(n, self.dim))
b = random.normal(size=(n, self.dim))
# Each row of (a + 1j*b) is CN(0, 2*I); transform by L^T to get CN(0, 2*C)
# Using regular transpose (not conjugate) so each row maps as z -> L @ z (column form)
z = (a + 1j * b) @ transpose(L)
norms = sqrt(real(sum(z * conj(z), axis=-1)))
return z / norms.reshape(-1, 1)

@staticmethod
def fit(Z, n_iterations=100):
"""Fit distribution to data Z using fixed-point iterations.

Parameters
----------
Z : array-like of shape (n, d)
Complex unit vectors (each row is a sample).
n_iterations : int, optional
Number of fixed-point iterations (default 100).

Returns
-------
dist : ComplexAngularCentralGaussianDistribution
"""
C = ComplexAngularCentralGaussianDistribution.estimate_parameter_matrix(
Z, n_iterations
)
return ComplexAngularCentralGaussianDistribution(C)

@staticmethod
def estimate_parameter_matrix(Z, n_iterations=100):
"""Estimate the parameter matrix from data using fixed-point iterations.

Parameters
----------
Z : array-like of shape (n, d)
Complex unit vectors (each row is a sample).
n_iterations : int, optional
Number of iterations (default 100).

Returns
-------
C : array-like of shape (d, d)
Estimated Hermitian parameter matrix.
"""
N = Z.shape[0]
D = Z.shape[1]
C = array(np.eye(D, dtype=complex))

for _ in range(n_iterations):
# Solve C * X = Z.T to get C^{-1} * Z.T, shape (d, n)
C_inv_Z = linalg.solve(C, transpose(Z))
# Hermitian quadratic forms: inner[k] = Z[k]^H C^{-1} Z[k]
inner = sum(conj(transpose(Z)) * C_inv_Z, axis=0) # shape (n,)
weights = (D - 1) / abs(inner) # shape (n,)
# C = (1/N) * sum_k weights[k] * z_k z_k^H
# = Z.T @ diag(weights) @ conj(Z) / N
C = transpose(Z) @ (weights.reshape(-1, 1) * conj(Z)) / N

return C
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import unittest

import numpy as np
import numpy.testing as npt
import pyrecest.backend

# pylint: disable=no-name-in-module,no-member
from pyrecest.backend import array
from pyrecest.distributions import ComplexAngularCentralGaussianDistribution


class TestComplexAngularCentralGaussianDistribution(unittest.TestCase):
def setUp(self):
"""Set up test fixtures."""
# Identity matrix case (uniform distribution on complex unit sphere)
self.C_identity_2d = array(np.eye(2, dtype=complex))
self.dist_identity_2d = ComplexAngularCentralGaussianDistribution(
self.C_identity_2d
)

# Non-trivial Hermitian positive definite matrix for 2D case
# C = [[2, 1+1j], [1-1j, 3]]
C_vals = np.array([[2.0, 1.0 + 1.0j], [1.0 - 1.0j, 3.0]])
self.C_nontrivial_2d = array(C_vals)
self.dist_nontrivial_2d = ComplexAngularCentralGaussianDistribution(
self.C_nontrivial_2d
)

def test_constructor_valid(self):
"""Test that constructor accepts a Hermitian matrix."""
self.assertEqual(self.dist_identity_2d.dim, 2)
self.assertEqual(self.dist_nontrivial_2d.dim, 2)

def test_constructor_non_hermitian_raises(self):
"""Test that constructor rejects a non-Hermitian matrix."""
C_bad = array(np.array([[1.0, 2.0 + 1.0j], [0.0, 1.0]]))
with self.assertRaises(AssertionError):
ComplexAngularCentralGaussianDistribution(C_bad)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_pdf_identity_uniform(self):
"""For C=I, the pdf should be constant gamma(d)/(2*pi^d) on the unit sphere."""
d = 2
# Expected: gamma(2) / (2*pi^2) = 1 / (2*pi^2)
expected = 1.0 / (2.0 * np.pi**d) # gamma(2)=1

# Test on several unit vectors
z1 = array(np.array([[1.0, 0.0]], dtype=complex))
z2 = array(np.array([[0.0, 1.0]], dtype=complex))
z3 = array((np.array([1.0, 1.0j]) / np.sqrt(2.0)).reshape(1, -1))
z4 = array((np.array([1.0 + 1.0j, 1.0 - 1.0j]) / 2.0).reshape(1, -1))

for z in [z1, z2, z3, z4]:
p = self.dist_identity_2d.pdf(z)
npt.assert_allclose(
float(np.real(np.array(p[0]))),
expected,
rtol=1e-6,
err_msg=f"PDF for identity C is not constant at {z}",
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_pdf_positive(self):
"""PDF values should be positive for any unit vector."""
z = array(np.array([[1.0 / np.sqrt(2.0), 1.0j / np.sqrt(2.0)]]))
p = self.dist_nontrivial_2d.pdf(z)
self.assertGreater(float(np.real(np.array(p[0]))), 0.0)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_pdf_batch_vs_single(self):
"""Batch PDF evaluation should match individual evaluations."""
zs = np.array(
[
[1.0, 0.0],
[0.0, 1.0],
[1.0 / np.sqrt(2.0), 1.0j / np.sqrt(2.0)],
],
dtype=complex,
)
za = array(zs)

p_batch = self.dist_nontrivial_2d.pdf(za)
for i, z in enumerate(zs):
p_single = self.dist_nontrivial_2d.pdf(array(z.reshape(1, -1)))
npt.assert_allclose(
float(np.real(np.array(p_batch[i]))),
float(np.real(np.array(p_single[0]))),
rtol=1e-10,
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_sample_unit_norm(self):
"""Sampled vectors should lie on the complex unit sphere."""
n = 100
Z = self.dist_nontrivial_2d.sample(n)
Z_np = np.array(Z)
norms_sq = np.array(
[np.real(np.sum(Z_np[k] * np.conj(Z_np[k]))) for k in range(n)]
)
npt.assert_allclose(norms_sq, np.ones(n), atol=1e-10)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_sample_correct_dim(self):
"""Sampled vectors should have the correct shape."""
n = 50
Z = self.dist_identity_2d.sample(n)
self.assertEqual(Z.shape[0], n)
self.assertEqual(Z.shape[1], 2)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_estimate_parameter_matrix_identity(self):
"""Fitting samples from identity-C distribution should recover approx identity."""
pyrecest.backend.random.seed(42) # pylint: disable=no-member
n = 2000
Z = self.dist_identity_2d.sample(n)
C_est = ComplexAngularCentralGaussianDistribution.estimate_parameter_matrix(
Z, n_iterations=100
)
# Normalize C_est to have trace equal to 2 (matching identity)
C_est_np = np.array(C_est)
C_est_normalized = C_est_np / np.trace(C_est_np).real * 2.0
npt.assert_allclose(
np.real(C_est_normalized),
np.eye(2),
atol=0.15,
err_msg="Estimated C does not approximately match identity",
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_fit_returns_distribution(self):
"""fit() should return a ComplexAngularCentralGaussianDistribution."""
pyrecest.backend.random.seed(0) # pylint: disable=no-member
Z = self.dist_identity_2d.sample(50)
dist = ComplexAngularCentralGaussianDistribution.fit(Z, n_iterations=10)
self.assertIsInstance(dist, ComplexAngularCentralGaussianDistribution)
self.assertEqual(dist.dim, 2)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax",
reason="Not supported on JAX backend",
) # pylint: disable=no-member
def test_3d_case(self):
"""Test basic functionality for d=3."""
C_3d = array(np.eye(3, dtype=complex))
dist = ComplexAngularCentralGaussianDistribution(C_3d)
self.assertEqual(dist.dim, 3)

Z = dist.sample(20)
self.assertEqual(Z.shape, (20, 3))

# Check unit norms
Z_np = np.array(Z)
norms_sq = np.array(
[np.real(np.sum(Z_np[k] * np.conj(Z_np[k]))) for k in range(20)]
)
npt.assert_allclose(norms_sq, np.ones(20), atol=1e-10)

# For d=3, C=I: pdf should be gamma(3)/(2*pi^3) = 2/(2*pi^3) = 1/pi^3
z_test = array(np.array([[1.0, 0.0, 0.0]], dtype=complex))
p = dist.pdf(z_test)
expected = 1.0 / np.pi**3 # gamma(3)=2, so 2/(2*pi^3)=1/pi^3
npt.assert_allclose(float(np.real(np.array(p[0]))), expected, rtol=1e-6)


if __name__ == "__main__":
unittest.main()
Loading