diff --git a/pyrecest/distributions/__init__.py b/pyrecest/distributions/__init__.py index 818450ca5..2c4bbf8f7 100644 --- a/pyrecest/distributions/__init__.py +++ b/pyrecest/distributions/__init__.py @@ -200,6 +200,10 @@ ) from .hypersphere_subset.von_mises_fisher_distribution import VonMisesFisherDistribution from .hypersphere_subset.watson_distribution import WatsonDistribution +from .hypersphere_subset.bayesian_complex_watson_mixture_model import ( + BayesianComplexWatsonMixtureModel, +) +from .hypersphere_subset.complex_watson_distribution import ComplexWatsonDistribution from .hypertorus.abstract_hypertoroidal_distribution import ( AbstractHypertoroidalDistribution, ) @@ -359,6 +363,8 @@ "SphericalHarmonicsDistributionReal", "VonMisesFisherDistribution", "WatsonDistribution", + "BayesianComplexWatsonMixtureModel", + "ComplexWatsonDistribution", "AbstractHypertoroidalDistribution", "AbstractToroidalBivarVMDistribution", "AbstractToroidalDistribution", diff --git a/pyrecest/distributions/hypersphere_subset/bayesian_complex_watson_mixture_model.py b/pyrecest/distributions/hypersphere_subset/bayesian_complex_watson_mixture_model.py new file mode 100644 index 000000000..5d27a2560 --- /dev/null +++ b/pyrecest/distributions/hypersphere_subset/bayesian_complex_watson_mixture_model.py @@ -0,0 +1,398 @@ +""" +Bayesian Complex Watson Mixture Model. + +Extends the complex Watson mixture model by adding: +- A complex Bingham prior for the mode vectors. +- A Dirichlet prior for the mixture weights. + +Fitting is performed via a variational EM algorithm. + +Reference: + Derived from libDirectional (MATLAB): + https://github.com/libDirectional/libDirectional +""" + +import numpy as np +from scipy.special import digamma + +from .complex_watson_distribution import ComplexWatsonDistribution + + +class BayesianComplexWatsonMixtureModel: + """ + Bayesian complex Watson mixture model. 
+ + Stores the posterior parameters (complex Bingham matrices B, concentration + parameters, and Dirichlet parameter alpha) resulting from fitting the model + to observations. + + Attributes + ---------- + B : ndarray of shape (D, D, K), complex + Complex Bingham parameter matrices for each component. + concentrations : ndarray of shape (K,) + Concentration parameters (kappa) for each Watson component. + alpha : ndarray of shape (K,) + Dirichlet parameter vector (proportional to mixture weights). + K : int + Number of mixture components. + dim : int + Dimension D of the complex space C^D. + """ + + def __init__(self, B, concentrations, alpha): + """ + Construct a BayesianComplexWatsonMixtureModel from posterior parameters. + + Args: + B: (D, D, K) complex array of Hermitian Bingham parameter matrices. + concentrations: (K,) array of concentration parameters. + alpha: (K,) Dirichlet parameter vector. + """ + B = np.asarray(B, dtype=complex) + concentrations = np.asarray(concentrations, dtype=float).ravel() + alpha = np.asarray(alpha, dtype=float).ravel() + + K = alpha.shape[0] + assert B.shape[2] == K, "B.shape[2] must equal len(alpha)" + assert concentrations.shape[0] == K, "len(concentrations) must equal len(alpha)" + + for k in range(K): + assert np.allclose( + B[:, :, k], B[:, :, k].conj().T, atol=1e-6 + ), f"B[:,:,{k}] must be Hermitian" + + self.B = B + self.concentrations = concentrations + self.alpha = alpha + self.K = K + self.dim = B.shape[0] + + @staticmethod + def fit(Z, parameters): + """ + Fit the model to observations Z using the given hyperparameters. + + Args: + Z: (D, N) complex matrix of unit-vector observations. + parameters (dict): Hyperparameter dict. Must contain: + - ``initial``: dict with keys ``B`` (D,D,K complex), ``alpha`` (K,), + and ``kappa`` (scalar or (K,)). + - ``prior``: dict with keys ``B`` (D,D,K complex) and ``alpha`` (K,). + - ``I`` (int): number of EM iterations. 
+ Optionally: + - ``uniformComponent`` (bool): if True, last component is uniform. + - ``prior['saliencies']`` (float or (N,)): observation saliencies. + + Returns: + tuple: (BayesianComplexWatsonMixtureModel, posterior dict) + """ + posterior = BayesianComplexWatsonMixtureModel.estimate_posterior(Z, parameters) + model = BayesianComplexWatsonMixtureModel( + posterior["B"], posterior["kappa"], posterior["alpha"] + ) + return model, posterior + + @staticmethod + def fit_default(Z, K): + """ + Fit the model with default hyperparameters. + + Args: + Z: (D, N) complex matrix of unit-vector observations. + The feature dimension D must be less than 100. + K (int): Number of mixture components. + + Returns: + tuple: (BayesianComplexWatsonMixtureModel, posterior dict) + """ + assert Z.shape[0] < 100, ( + "fit_default assumes D < 100 (feature dimension, not sample count)" + ) + D = Z.shape[0] + parameters = BayesianComplexWatsonMixtureModel.parameters_default(D, K) + return BayesianComplexWatsonMixtureModel.fit(Z, parameters) + + @staticmethod + def parameters_default(D, K): + """ + Build the default hyperparameter dict for dimension D and K components. + + Args: + D (int): Dimension of the complex space. + K (int): Number of mixture components. + + Returns: + dict: Default parameter dict compatible with ``fit`` and + ``estimate_posterior``. + """ + parameters = {} + parameters["initial"] = { + "B": np.zeros((D, D, K), dtype=complex), + "kappa": 20.0, + "alpha": 1.0 / K + np.linspace(-0.14 / K, 0.14 / K, K), + } + parameters["prior"] = { + "B": np.zeros((D, D, K), dtype=complex), + "alpha": np.ones(K) / K, + "saliencies": 1.0, + } + parameters["I"] = 40 + parameters["uniformComponent"] = False + return parameters + + @staticmethod + def estimate_posterior(Z, parameters): + """ + Run the variational EM algorithm to estimate posterior parameters. + + E-step: update soft assignments gamma[n, k] using the current parameters. 
+ M-step: update B, alpha, and kappa from the weighted sufficient statistics. + + Args: + Z: (D, N) complex matrix of unit-vector observations. + parameters (dict): Hyperparameter dict as returned by + ``parameters_default``. + + Returns: + dict: Posterior with keys ``B``, ``kappa``, ``alpha``, ``gamma``. + """ + uniform_component = parameters.get("uniformComponent", False) + + assert "initial" in parameters + assert "B" in parameters["initial"] + assert "alpha" in parameters["initial"] + assert "kappa" in parameters["initial"] + assert "prior" in parameters + assert "B" in parameters["prior"] + assert "alpha" in parameters["prior"] + assert "I" in parameters + + Z = np.asarray(Z, dtype=complex) + D, N = Z.shape + K = len(np.asarray(parameters["initial"]["alpha"]).ravel()) + + B_init = np.asarray(parameters["initial"]["B"], dtype=complex) + for k in range(K): + assert np.allclose( + B_init[:, :, k], B_init[:, :, k].conj().T, atol=1e-6 + ), "initial B must be Hermitian" + + kappa_init = parameters["initial"]["kappa"] + if np.ndim(kappa_init) == 0: + kappa_init_arr = np.full(K, float(kappa_init)) + else: + kappa_init_arr = np.asarray(kappa_init, dtype=float).ravel().copy() + + posterior = { + "B": B_init.copy(), + "alpha": np.asarray( + parameters["initial"]["alpha"], dtype=float + ).ravel().copy(), + "kappa": kappa_init_arr, + "gamma": np.zeros((N, K)), + } + + # Precompute outer products Z[:, n] * conj(Z)[:, n] reshaped to (D*D, N) + ZZ = ( + Z[:, np.newaxis, :] * Z.conj()[np.newaxis, :, :] + ).reshape(D * D, N) + + # Log saliencies shape (N, K) + saliencies = parameters["prior"].get("saliencies", 1.0) + saliencies_arr = np.asarray(saliencies) + if saliencies_arr.ndim == 0: + ln_saliencies = np.full((N, K), np.log(max(float(saliencies_arr), 1e-7))) + else: + saliencies_vec = saliencies_arr.ravel() + assert len(saliencies_vec) == N + ln_saliencies = ( + np.log(np.maximum(saliencies_vec, 1e-7))[:, np.newaxis] + * np.ones((N, K)) + ) + + prior_B = 
np.asarray(parameters["prior"]["B"], dtype=complex) + prior_alpha = np.asarray(parameters["prior"]["alpha"], dtype=float).ravel() + concentration_max = 500.0 + + for _ in range(parameters["I"]): + # E-step + log_gamma = ln_saliencies.copy() + + quad = BayesianComplexWatsonMixtureModel.quadratic_expectation( + ZZ.reshape(D, D, N), posterior["B"] + ) + log_gamma += posterior["kappa"][np.newaxis, :] * quad + log_gamma += ComplexWatsonDistribution.log_norm(D, posterior["kappa"])[ + np.newaxis, : + ] + log_gamma += digamma(posterior["alpha"])[np.newaxis, :] + + log_gamma -= log_gamma.max(axis=1, keepdims=True) + gamma = np.exp(log_gamma) + gamma /= gamma.sum(axis=1, keepdims=True) + + assert not np.any(np.isnan(gamma)), "NaN in gamma during E-step" + posterior["gamma"] = gamma + + # M-step + N_k = gamma.sum(axis=0) + + posterior["alpha"] = prior_alpha + N_k + + cov_matrix = (ZZ @ gamma) / np.maximum(N_k[np.newaxis, :], 1e-300) + cov_matrix = cov_matrix.reshape(D, D, K) + + for k in range(K): + posterior["B"][:, :, k] = ( + posterior["kappa"][k] * N_k[k] * cov_matrix[:, :, k] + + prior_B[:, :, k] + ) + posterior["B"][:, :, k] = 0.5 * ( + posterior["B"][:, :, k] + posterior["B"][:, :, k].conj().T + ) + + for k in range(K): + if uniform_component and k == K - 1: + posterior["kappa"][k] = 0.0 + continue + + cov_k = cov_matrix[:, :, k].reshape(D, D, 1) + Bk = posterior["B"][:, :, k].reshape(D, D, 1) + quad_k = float( + np.real( + BayesianComplexWatsonMixtureModel.quadratic_expectation( + cov_k, Bk + )[0, 0] + ) + ) + posterior["kappa"][k] = ( + ComplexWatsonDistribution._hypergeometric_ratio_inverse( + quad_k, D, concentration_max=concentration_max + ) + ) + + return posterior + + @staticmethod + def quadratic_expectation(dyadic_products, B): + """ + Compute E_{z ~ cBingham(B_k)}[z^H * A * z] for each A in dyadic_products. 
+ + Approximation: + - Large eigenvalue regime (any eigenvalue > 1): use first-order moments of + the complex Bingham, computed by numerical differentiation. + - Otherwise: assume uniform (E[zz^H] = I/D). + + Args: + dyadic_products: (D, D, N) complex array of D x D matrices A_n. + B: (D, D, K) complex Hermitian Bingham parameters, or (D, D) for K=1. + + Returns: + ndarray of shape (N, K): real values E[z^H A_n z]. + """ + dyadic_products = np.asarray(dyadic_products, dtype=complex) + B = np.asarray(B, dtype=complex) + + if B.ndim == 2: + B = B[:, :, np.newaxis] + + D = B.shape[0] + N = dyadic_products.shape[2] if dyadic_products.ndim == 3 else 1 + K = B.shape[2] + + if dyadic_products.ndim == 2: + dyadic_products = dyadic_products[:, :, np.newaxis] + + dp_reshaped = dyadic_products.reshape(D * D, N) + E = np.zeros((N, K)) + + for k in range(K): + Bk = 0.5 * (B[:, :, k] + B[:, :, k].conj().T) + eigenvalues, U = np.linalg.eigh(Bk) + + idx = np.argsort(eigenvalues)[::-1] + Lambda = np.real(eigenvalues[idx]) + U = U[:, idx] + + if np.any(Lambda > 1.0): + Lambda_perturbed = Lambda + np.arange(1, D + 1) * 1e-2 + Lambda_shifted = Lambda_perturbed - Lambda_perturbed.max() + c_diag = _complex_bingham_first_order_moments(Lambda_shifted, D) + cov_k = U @ np.diag(c_diag) @ U.conj().T + else: + cov_k = np.eye(D, dtype=complex) / D + + cov_vec = cov_k.ravel(order="C") + E[:, k] = np.real(dp_reshaped.T @ cov_vec.conj()) + + return E + + +# --------------------------------------------------------------------------- +# Helpers: first-order moments and simplex integral +# --------------------------------------------------------------------------- + + +def _complex_bingham_first_order_moments(Lambda_shifted, D): + """ + Compute E[|z_i|^2] for a diagonal complex Bingham with shifted eigenvalues. + + Uses numerical differentiation of log(int_simplex exp(Lambda.s) ds). + + Args: + Lambda_shifted: D-dim real array, shifted so max = 0. + D (int): Dimension. 
+ + Returns: + ndarray: D-dim non-negative real array normalised to sum 1. + """ + Lambda = np.asarray(Lambda_shifted, dtype=float) + eps = 1e-5 + log_F0 = np.log(max(_simplex_integral(Lambda), 1e-300)) + moments = np.zeros(D) + for i in range(D): + L_plus = Lambda.copy() + L_plus[i] += eps + log_F_plus = np.log(max(_simplex_integral(L_plus), 1e-300)) + moments[i] = (log_F_plus - log_F0) / eps + + total = moments.sum() + if total > 1e-10: + moments /= total + else: + moments = np.ones(D) / D + return moments + + +def _simplex_integral(Lambda): + """ + Compute int_{standard simplex} exp(Lambda . s) ds. + + Uses the partial-fractions / divided-differences formula: + I = sum_i exp(Lambda_i) / prod_{j!=i} (Lambda_i - Lambda_j) + + Args: + Lambda: D-dim real array of eigenvalues. + + Returns: + float: Integral value (always positive). + """ + Lambda = np.asarray(Lambda, dtype=float).copy() + D = len(Lambda) + + if D == 1: + return float(np.exp(Lambda[0])) + + # Tiny perturbation to resolve exact degeneracy + Lambda = Lambda + np.arange(D) * 1e-10 + + result = 0.0 + for i in range(D): + denom = 1.0 + for j in range(D): + if j != i: + denom *= Lambda[i] - Lambda[j] + if abs(denom) > 1e-300: + result += np.exp(Lambda[i]) / denom + return float(result) diff --git a/pyrecest/distributions/hypersphere_subset/complex_watson_distribution.py b/pyrecest/distributions/hypersphere_subset/complex_watson_distribution.py new file mode 100644 index 000000000..b55d1ca83 --- /dev/null +++ b/pyrecest/distributions/hypersphere_subset/complex_watson_distribution.py @@ -0,0 +1,344 @@ +""" +Complex Watson distribution on the complex unit sphere in C^D. + +Reference: + Mardia, K. V. & Dryden, I. L. + The Complex Watson Distribution and Shape Analysis + Journal of the Royal Statistical Society: Series B + (Statistical Methodology), Blackwell Publishers Ltd., 1999, 61, 913-926. 
+""" + +import math + +import numpy as np +from scipy.optimize import brentq +from scipy.special import gammaln, hyp1f1 + + +class ComplexWatsonDistribution: + """ + Complex Watson distribution on the complex unit sphere in C^D. + + The PDF is: f(z; mu, kappa) = C(D, kappa)^{-1} * exp(kappa * |mu^H z|^2) + + where z in C^D is a unit vector (|z|=1), mu in C^D is the mode (unit vector), + kappa >= 0 is the concentration parameter, and C(D, kappa) is the normalization. + """ + + EPSILON = 1e-6 + + def __init__(self, mu, kappa): + """ + Initializes the ComplexWatsonDistribution. + + Args: + mu: D-dimensional complex unit vector (the mode direction). + kappa (float): Concentration parameter (>= 0). + """ + mu = np.asarray(mu, dtype=complex) + assert mu.ndim == 1, "mu must be a 1-D vector" + assert ( + abs(np.linalg.norm(mu) - 1.0) < self.EPSILON + ), "mu must be normalized (|mu|=1)" + + self.mu = mu + self.kappa = float(kappa) + self.dim = len(mu) # D: dimension of the complex space C^D + self._log_c = ComplexWatsonDistribution.log_norm(self.dim, self.kappa) + + @staticmethod + def log_norm(D, kappa): + """ + Compute the log normalization constant for the complex Watson distribution. + + Returns -log(C(D, kappa)) where + log C(D, kappa) = log(2) + D*log(pi) - log(Gamma(D)) + log(1F1(1; D; kappa)) + + Three regimes are used for numerical stability: + - Low kappa (kappa < 1/D): Taylor series + - Medium kappa (1/D <= kappa < 100): intermediate correction + - High kappa (kappa >= 100): asymptotic approximation + + Args: + D (int): Dimension of the complex space. + kappa: Concentration parameter(s) — scalar or array. + + Returns: + float or ndarray: -log(C(D, kappa)), same shape as kappa. 
+ """ + scalar_input = np.ndim(kappa) == 0 + kappa = np.atleast_1d(np.asarray(kappa, dtype=float)).ravel() + log_c = np.zeros_like(kappa) + + # Asymptotic formula for high kappa + # log C ~ log(2) + D*log(pi) + (1-D)*log(kappa) + kappa + # log_c_high is evaluated for all kappa before masking; clip to avoid log(0) warning + log_c_high = ( + math.log(2) + D * math.log(math.pi) + + (1 - D) * np.log(np.maximum(kappa, 1e-300)) + kappa + ) + + # Intermediate formula (Mardia1999 Eq. 3): + # log C = log_c_high + log(1 - sum_{j=0}^{D-2} kappa^j * exp(-kappa) / j!) + running = np.exp(-kappa) + correction_sum = running.copy() + for j in range(1, D - 1): + running = running * kappa / j + correction_sum = correction_sum + running + if D >= 2: + log_c_medium = log_c_high + np.log( + np.maximum(1.0 - correction_sum, 1e-300) + ) + else: + log_c_medium = log_c_high + + # Taylor series for low kappa (Mardia1999 Eq. 4): + # 1F1(1; D; kappa) = 1 + kappa/D + kappa^2/(D*(D+1)) + ... + running_prod = np.ones_like(kappa) + series_sum = np.ones_like(kappa) + for j in range(10): + running_prod = running_prod * kappa / (D + j) + series_sum = series_sum + running_prod + log_c_low = ( + math.log(2) + D * math.log(math.pi) - gammaln(D) + np.log(series_sum) + ) + + mask_low = kappa < 1.0 / D + mask_high = kappa >= 100.0 + mask_medium = ~mask_low & ~mask_high + + log_c[mask_low] = log_c_low[mask_low] + log_c[mask_medium] = log_c_medium[mask_medium] + log_c[mask_high] = log_c_high[mask_high] + + result = -log_c + if scalar_input: + return float(result[0]) + return result + + def pdf(self, Z): + """ + Evaluate the PDF at the columns of Z. + + Args: + Z: D x N complex matrix where each column is a unit vector in C^D. + May also be a 1-D array (single vector). + + Returns: + ndarray: N-dimensional real array of PDF values. 
+ """ + Z = np.asarray(Z, dtype=complex) + if Z.ndim == 1: + Z = Z.reshape(-1, 1) + # |mu^H z|^2 for each column + inner = np.abs(self.mu.conj() @ Z) ** 2 + return np.real(np.exp(self._log_c + self.kappa * inner)) + + def sample(self, n): + """ + Draw n unit vectors from the complex Watson distribution. + + Uses the complex Bingham representation: + B = -kappa * (I - mu mu^H) + + Args: + n (int): Number of samples. + + Returns: + ndarray: D x n complex matrix of samples. + """ + B = -self.kappa * ( + np.eye(self.dim, dtype=complex) - np.outer(self.mu, self.mu.conj()) + ) + B = 0.5 * (B + B.conj().T) + return _sample_complex_bingham(B, n) + + @staticmethod + def fit(Z, weights=None): + """ + Fit a ComplexWatsonDistribution to data using MLE. + + Args: + Z: D x N complex matrix of observations (unit vectors). + weights: Optional 1 x N or (N,) real weight array. + + Returns: + ComplexWatsonDistribution: Fitted distribution. + """ + mu_hat, kappa_hat = ComplexWatsonDistribution.estimate_parameters(Z, weights) + return ComplexWatsonDistribution(mu_hat, kappa_hat) + + @staticmethod + def estimate_parameters(Z, weights=None): + """ + MLE estimation of complex Watson parameters. + + Method: Mardia & Dryden (1999), Section 4. + + Args: + Z: D x N complex matrix of observations (unit vectors). + weights: Optional 1 x N or (N,) real weight array. + + Returns: + tuple: (mu_hat, kappa_hat) — complex unit mode vector and concentration. 
+ """ + Z = np.asarray(Z, dtype=complex) + D, N = Z.shape + + if weights is None: + S = Z @ Z.conj().T + else: + weights = np.asarray(weights, dtype=float).ravel() + assert len(weights) == N, "weights length must match number of samples" + S = (Z * weights) @ Z.conj().T * N / np.sum(weights) + + S = 0.5 * (S + S.conj().T) # enforce Hermitian + + eigenvalues, eigenvectors = np.linalg.eigh(S) + idx = np.argmax(eigenvalues) + mu_hat = eigenvectors[:, idx] + lambda_max = float(np.real(eigenvalues[idx])) + + normed_lambda = lambda_max / N + + # High-concentration approximation (Mardia & Dryden 1999) + kappa_approx = N * (D - 1) / max(N - lambda_max, 1e-300) + if kappa_approx < 200: + kappa_hat = ComplexWatsonDistribution._hypergeometric_ratio_inverse( + normed_lambda, D, concentration_max=1000 + ) + else: + kappa_hat = kappa_approx + + return mu_hat, kappa_hat + + @staticmethod + def _hypergeometric_ratio(kappa, D): + """ + Compute E[|mu^H z|^2] = d(log C)/d(kappa) for the complex Watson distribution. + + This equals (1/D) * 1F1(2; D+1; kappa) / 1F1(1; D; kappa). + + Args: + kappa (float): Concentration. + D (int): Dimension of the complex space. + + Returns: + float: Expected value of |mu^H z|^2, in [1/D, 1). + """ + kappa = float(kappa) + if kappa < 1e-10: + return 1.0 / D + # For large kappa use asymptotic: ratio ~ 1 - (D-1)/kappa + # (from d/dkappa [kappa + (1-D)*log(kappa) + ...] = 1 + (1-D)/kappa) + if kappa >= 100.0: + return 1.0 - float(D - 1) / kappa + # Differentiate log_norm numerically to avoid hyp1f1 overflow + eps = max(kappa * 1e-4, 1e-7) + log_c_plus = ComplexWatsonDistribution.log_norm(D, kappa + eps) + log_c_minus = ComplexWatsonDistribution.log_norm(D, max(kappa - eps, 0.0)) + # ratio = d(log C)/d(kappa) = -d(log_norm)/d(kappa) [log_norm = -log C] + return -(log_c_plus - log_c_minus) / (2.0 * eps) + + @staticmethod + def _hypergeometric_ratio_inverse(r, D, concentration_max=500): + """ + Find kappa such that _hypergeometric_ratio(kappa, D) == r. 
+ + Args: + r (float): Target ratio, should be in (1/D, 1). + D (int): Dimension of the complex space. + concentration_max (float): Upper bound for bracket search. + + Returns: + float: kappa value. + """ + r = float(r) + lower = 1.0 / D + if r <= lower + 1e-10: + return 0.0 + if r >= 1.0 - 1e-10: + return float(concentration_max) + + def objective(k): + return ComplexWatsonDistribution._hypergeometric_ratio(k, D) - r + + return brentq(objective, 0.0, float(concentration_max), xtol=1e-8) + + +# --------------------------------------------------------------------------- +# Sampling helpers +# --------------------------------------------------------------------------- + + +def _sample_complex_bingham(B, n): + """ + Sample n unit vectors from a complex Bingham distribution with parameter B. + + Implements the algorithm from: + Mardia, K. V. & Jupp, P. E. Directional Statistics, Wiley, 2009, p. 336. + + Args: + B: D x D complex Hermitian matrix (concentration matrix). + n (int): Number of samples. + + Returns: + ndarray: D x n complex matrix of samples (each column is a unit vector). + """ + B = np.asarray(B, dtype=complex) + D = B.shape[0] + B = 0.5 * (B + B.conj().T) + + # Eigendecompose -B (so eigenvalues are non-negative in descending order) + eigenvalues, V = np.linalg.eigh(-B) + idx = np.argsort(eigenvalues)[::-1] + Lambda = np.real(eigenvalues[idx]) + V = V[:, idx] + + # Shift so smallest eigenvalue is 0 (doesn't change the distribution) + Lambda = Lambda - Lambda[-1] + + Z = np.zeros((D, n), dtype=complex) + for i in range(n): + s = _sample_diagonal_complex_bingham_magnitudes(Lambda, D) + theta = 2.0 * np.pi * np.random.rand(D) + w = np.sqrt(s) * np.exp(1j * theta) + Z[:, i] = V @ w + + return Z + + +def _sample_diagonal_complex_bingham_magnitudes(Lambda, D): + """ + Sample squared magnitudes |z_i|^2 from a diagonal complex Bingham distribution. + + The resulting vector s satisfies s_i >= 0 and sum(s) = 1. 
+ + Args: + Lambda: D-dimensional non-negative eigenvalue vector (largest first, last=0). + D (int): Dimension. + + Returns: + ndarray: D-dimensional vector of squared magnitudes. + """ + Lambda_pos = Lambda[: D - 1] # first D-1 (positive) eigenvalues + + # Precompute for the truncated exponential inverse CDF + large = Lambda_pos >= 0.03 + safe_lambda = np.where(large, Lambda_pos, 1.0) + temp1 = np.where(large, -1.0 / safe_lambda, 0.0) + temp2 = np.where(large, 1.0 - np.exp(-Lambda_pos), 0.0) + + s = np.zeros(D) + while True: + U = np.random.rand(D - 1) + if np.any(large): + s[: D - 1][large] = temp1[large] * np.log(1.0 - U[large] * temp2[large]) + if np.any(~large): + s[: D - 1][~large] = U[~large] + + if np.sum(s[: D - 1]) < 1.0: + break + + s[D - 1] = 1.0 - np.sum(s[: D - 1]) + return s diff --git a/pyrecest/tests/distributions/test_bayesian_complex_watson_mixture_model.py b/pyrecest/tests/distributions/test_bayesian_complex_watson_mixture_model.py new file mode 100644 index 000000000..705da6c92 --- /dev/null +++ b/pyrecest/tests/distributions/test_bayesian_complex_watson_mixture_model.py @@ -0,0 +1,198 @@ +import unittest + +import numpy as np +import numpy.testing as npt + +from pyrecest.distributions.hypersphere_subset.complex_watson_distribution import ( + ComplexWatsonDistribution, +) +from pyrecest.distributions.hypersphere_subset.bayesian_complex_watson_mixture_model import ( + BayesianComplexWatsonMixtureModel, + _simplex_integral, + _complex_bingham_first_order_moments, +) + + +def _make_unit_vectors(D, N, rng): + Z = rng.standard_normal((D, N)) + 1j * rng.standard_normal((D, N)) + Z /= np.linalg.norm(Z, axis=0) + return Z + + +class TestSimplexIntegralMixture(unittest.TestCase): + def test_D2(self): + a, b = 2.0, 1.0 + expected = (np.exp(a) - np.exp(b)) / (a - b) + self.assertAlmostEqual(_simplex_integral(np.array([a, b])), expected, places=8) + + def test_D3(self): + Lambda = np.array([2.0, 1.0, 0.0]) + expected = np.exp(2) / 2 - np.exp(1) + 0.5 + 
# Reconstructed test modules. The collapsed patch fuses two test files in this
# span; both are reproduced here in full:
#   pyrecest/tests/distributions/test_bayesian_complex_watson_mixture_model.py
#   pyrecest/tests/distributions/test_complex_watson_distribution.py

import unittest

import numpy as np
import numpy.testing as npt

from pyrecest.distributions.hypersphere_subset.complex_watson_distribution import (
    ComplexWatsonDistribution,
)
from pyrecest.distributions.hypersphere_subset.bayesian_complex_watson_mixture_model import (
    BayesianComplexWatsonMixtureModel,
    _simplex_integral,
    _complex_bingham_first_order_moments,
)


def _make_unit_vectors(D, N, rng):
    """Return a D x N matrix of random complex unit column vectors."""
    Z = rng.standard_normal((D, N)) + 1j * rng.standard_normal((D, N))
    Z /= np.linalg.norm(Z, axis=0)
    return Z


def _random_unit_vector(D, rng=None):
    """Return a random unit vector in C^D."""
    if rng is None:
        rng = np.random.default_rng(42)
    z = rng.standard_normal(D) + 1j * rng.standard_normal(D)
    return z / np.linalg.norm(z)


class TestSimplexIntegralMixture(unittest.TestCase):
    def test_D2(self):
        a, b = 2.0, 1.0
        expected = (np.exp(a) - np.exp(b)) / (a - b)
        self.assertAlmostEqual(_simplex_integral(np.array([a, b])), expected, places=8)

    def test_D3(self):
        Lambda = np.array([2.0, 1.0, 0.0])
        expected = np.exp(2) / 2 - np.exp(1) + 0.5
        self.assertAlmostEqual(_simplex_integral(Lambda), expected, places=5)


class TestComplexBinghamFirstOrderMoments(unittest.TestCase):
    def test_uniform_zero_eigenvalues(self):
        D = 3
        Lambda = np.zeros(D)
        moments = _complex_bingham_first_order_moments(Lambda, D)
        npt.assert_allclose(moments, np.ones(D) / D, atol=1e-2)

    def test_sum_to_one(self):
        D = 4
        Lambda = np.array([-1.0, -2.0, -3.0, -4.0])
        Lambda -= Lambda.max()
        moments = _complex_bingham_first_order_moments(Lambda, D)
        self.assertAlmostEqual(moments.sum(), 1.0, places=5)

    def test_largest_eigenvalue_largest_moment(self):
        D = 3
        Lambda = np.array([0.0, -5.0, -10.0])
        moments = _complex_bingham_first_order_moments(Lambda, D)
        self.assertGreater(moments[0], moments[1])
        self.assertGreater(moments[1], moments[2])


class TestQuadraticExpectation(unittest.TestCase):
    def test_identity_input_equals_moments_sum(self):
        """E[z^H I z] = E[|z|^2] = 1 by definition."""
        D = 3
        K = 2
        B = np.zeros((D, D, K), dtype=complex)
        I_3d = np.eye(D, dtype=complex)[:, :, np.newaxis]
        E = BayesianComplexWatsonMixtureModel.quadratic_expectation(I_3d, B)
        npt.assert_allclose(E, np.ones((1, K)), atol=1e-8)

    def test_shape(self):
        D = 4
        N = 10
        K = 3
        rng = np.random.default_rng(1)
        Z = _make_unit_vectors(D, N, rng)
        dp = Z[:, np.newaxis, :] * Z.conj()[np.newaxis, :, :]
        dp = dp.reshape(D, D, N)
        B = np.zeros((D, D, K), dtype=complex)
        E = BayesianComplexWatsonMixtureModel.quadratic_expectation(dp, B)
        self.assertEqual(E.shape, (N, K))

    def test_real_output(self):
        D = 3
        N = 5
        K = 2
        rng = np.random.default_rng(2)
        Z = _make_unit_vectors(D, N, rng)
        dp = (Z[:, np.newaxis, :] * Z.conj()[np.newaxis, :, :]).reshape(D, D, N)
        B = np.zeros((D, D, K), dtype=complex)
        E = BayesianComplexWatsonMixtureModel.quadratic_expectation(dp, B)
        self.assertTrue(np.all(np.isreal(E)))


class TestBayesianComplexWatsonMixtureModelConstructor(unittest.TestCase):
    def test_basic_construction(self):
        D, K = 3, 2
        B = np.zeros((D, D, K), dtype=complex)
        concentrations = np.array([5.0, 10.0])
        alpha = np.array([1.0, 1.0])
        model = BayesianComplexWatsonMixtureModel(B, concentrations, alpha)
        self.assertEqual(model.K, K)
        self.assertEqual(model.dim, D)

    def test_non_hermitian_B_raises(self):
        D, K = 2, 1
        B = np.ones((D, D, K), dtype=complex) * (1 + 1j)
        with self.assertRaises(AssertionError):
            BayesianComplexWatsonMixtureModel(B, np.array([1.0]), np.array([1.0]))


class TestParametersDefault(unittest.TestCase):
    def test_keys_present(self):
        params = BayesianComplexWatsonMixtureModel.parameters_default(4, 3)
        self.assertIn("initial", params)
        self.assertIn("prior", params)
        self.assertIn("I", params)
        self.assertIn("B", params["initial"])
        self.assertIn("kappa", params["initial"])
        self.assertIn("alpha", params["initial"])

    def test_shapes(self):
        D, K = 5, 4
        params = BayesianComplexWatsonMixtureModel.parameters_default(D, K)
        self.assertEqual(params["initial"]["B"].shape, (D, D, K))
        self.assertEqual(len(params["initial"]["alpha"]), K)


class TestFitDefault(unittest.TestCase):
    def test_fit_returns_model(self):
        rng = np.random.default_rng(42)
        D, K, N = 3, 2, 50
        Z = _make_unit_vectors(D, N, rng)
        model, _ = BayesianComplexWatsonMixtureModel.fit_default(Z, K)
        self.assertIsInstance(model, BayesianComplexWatsonMixtureModel)
        self.assertEqual(model.K, K)
        self.assertEqual(model.dim, D)

    def test_posterior_keys(self):
        rng = np.random.default_rng(0)
        D, K, N = 3, 2, 30
        Z = _make_unit_vectors(D, N, rng)
        _, posterior = BayesianComplexWatsonMixtureModel.fit_default(Z, K)
        for key in ("B", "kappa", "alpha", "gamma"):
            self.assertIn(key, posterior)

    def test_gamma_sums_to_one(self):
        rng = np.random.default_rng(1)
        D, K, N = 3, 2, 40
        Z = _make_unit_vectors(D, N, rng)
        _, posterior = BayesianComplexWatsonMixtureModel.fit_default(Z, K)
        gamma = posterior["gamma"]
        npt.assert_allclose(gamma.sum(axis=1), np.ones(N), atol=1e-8)

    def test_alpha_positive(self):
        rng = np.random.default_rng(2)
        D, K, N = 3, 2, 40
        Z = _make_unit_vectors(D, N, rng)
        _, posterior = BayesianComplexWatsonMixtureModel.fit_default(Z, K)
        self.assertTrue(np.all(posterior["alpha"] > 0))

    def test_kappa_nonnegative(self):
        rng = np.random.default_rng(3)
        D, K, N = 3, 3, 60
        Z = _make_unit_vectors(D, N, rng)
        _, posterior = BayesianComplexWatsonMixtureModel.fit_default(Z, K)
        self.assertTrue(np.all(posterior["kappa"] >= 0))

    def test_B_hermitian_after_fit(self):
        rng = np.random.default_rng(4)
        D, K, N = 3, 2, 40
        Z = _make_unit_vectors(D, N, rng)
        model, _ = BayesianComplexWatsonMixtureModel.fit_default(Z, K)
        for k in range(K):
            npt.assert_allclose(
                model.B[:, :, k],
                model.B[:, :, k].conj().T,
                atol=1e-8,
                err_msg=f"B[:,:,{k}] is not Hermitian",
            )

    def test_fit_two_cluster_recovery(self):
        """Fit on data from two distinct clusters should assign high weight to each."""
        D = 3
        # Two orthogonal modes
        mu1 = np.array([1.0, 0.0, 0.0], dtype=complex)
        mu2 = np.array([0.0, 1.0, 0.0], dtype=complex)
        dist1 = ComplexWatsonDistribution(mu1, 20.0)
        dist2 = ComplexWatsonDistribution(mu2, 20.0)
        Z = np.concatenate([dist1.sample(60), dist2.sample(60)], axis=1)
        K = 2
        params = BayesianComplexWatsonMixtureModel.parameters_default(D, K)
        params["I"] = 50
        _, posterior = BayesianComplexWatsonMixtureModel.fit(Z, params)
        # Both components should have non-trivial assignment
        N_k = posterior["gamma"].sum(axis=0)
        self.assertGreater(min(N_k), 10.0)


class TestComplexWatsonLogNorm(unittest.TestCase):
    """Tests for the log normalisation constant."""

    def test_scalar_input(self):
        val = ComplexWatsonDistribution.log_norm(3, 5.0)
        self.assertIsInstance(val, float)

    def test_array_input(self):
        kappas = np.array([0.1, 1.0, 10.0, 200.0])
        vals = ComplexWatsonDistribution.log_norm(3, kappas)
        self.assertEqual(vals.shape, (4,))

    def test_D2_low_kappa(self):
        # For low kappa, log C ~ log(2) + 2*log(pi) - log(Gamma(2)) + log(1) = log(2 pi^2)
        # so log_c = -log(2*pi^2)
        log_c = ComplexWatsonDistribution.log_norm(2, 1e-10)
        expected = -np.log(2 * np.pi**2)
        self.assertAlmostEqual(log_c, expected, places=5)

    def test_continuity_across_regimes(self):
        # Values should be continuous at regime boundaries 1/D and 100
        D = 3
        eps = 1e-4
        # Boundary kappa ~ 1/D = 1/3
        v1 = ComplexWatsonDistribution.log_norm(D, 1.0 / D - eps)
        v2 = ComplexWatsonDistribution.log_norm(D, 1.0 / D + eps)
        self.assertAlmostEqual(v1, v2, places=2)
        # Boundary kappa ~ 100
        v3 = ComplexWatsonDistribution.log_norm(D, 100.0 - eps)
        v4 = ComplexWatsonDistribution.log_norm(D, 100.0 + eps)
        self.assertAlmostEqual(v3, v4, places=2)


class TestComplexWatsonDistribution(unittest.TestCase):
    """Tests for ComplexWatsonDistribution."""

    def setUp(self):
        self.rng = np.random.default_rng(0)

    def _unit_mu(self, D):
        z = self.rng.standard_normal(D) + 1j * self.rng.standard_normal(D)
        return z / np.linalg.norm(z)

    def test_constructor(self):
        D = 3
        mu = self._unit_mu(D)
        dist = ComplexWatsonDistribution(mu, 2.0)
        self.assertEqual(dist.dim, D)
        self.assertAlmostEqual(dist.kappa, 2.0)
        npt.assert_array_almost_equal(dist.mu, mu)

    def test_pdf_positive(self):
        D = 3
        mu = self._unit_mu(D)
        dist = ComplexWatsonDistribution(mu, 5.0)
        Z = np.stack([_random_unit_vector(D, self.rng) for _ in range(10)], axis=1)
        p = dist.pdf(Z)
        self.assertTrue(np.all(p > 0))

    def test_pdf_mode_is_maximum(self):
        """The PDF at the mode (mu) should be >= the PDF at any other point."""
        D = 3
        mu = self._unit_mu(D)
        dist = ComplexWatsonDistribution(mu, 10.0)
        Z = np.stack([_random_unit_vector(D, self.rng) for _ in range(50)], axis=1)
        p_mode = dist.pdf(mu.reshape(-1, 1))[0]
        self.assertTrue(np.all(p_mode >= dist.pdf(Z) - 1e-10))

    def test_pdf_antipodal_symmetry(self):
        """The PDF should be the same at z and -z (|mu^H z|^2 = |mu^H (-z)|^2)."""
        D = 3
        mu = self._unit_mu(D)
        dist = ComplexWatsonDistribution(mu, 5.0)
        z = _random_unit_vector(D, self.rng).reshape(-1, 1)
        npt.assert_allclose(dist.pdf(z), dist.pdf(-z), rtol=1e-10)

    def test_pdf_phase_invariance(self):
        """Multiplying z by a complex phase should not change the PDF."""
        D = 3
        mu = self._unit_mu(D)
        dist = ComplexWatsonDistribution(mu, 5.0)
        z = _random_unit_vector(D, self.rng).reshape(-1, 1)
        phase = np.exp(1j * 1.23)
        npt.assert_allclose(dist.pdf(z), dist.pdf(phase * z), rtol=1e-10)

    def test_hypergeometric_ratio_bounds(self):
        D = 4
        r0 = ComplexWatsonDistribution._hypergeometric_ratio(0.0, D)
        r_large = ComplexWatsonDistribution._hypergeometric_ratio(1000.0, D)
        self.assertAlmostEqual(r0, 1.0 / D, places=5)
        self.assertGreater(r_large, 0.99)
        self.assertLessEqual(r_large, 1.0)

    def test_hypergeometric_ratio_inverse(self):
        D = 3
        for kappa in [0.5, 5.0, 50.0]:
            r = ComplexWatsonDistribution._hypergeometric_ratio(kappa, D)
            kappa_hat = ComplexWatsonDistribution._hypergeometric_ratio_inverse(r, D)
            self.assertAlmostEqual(kappa_hat, kappa, places=3)

    def test_sample_on_unit_sphere(self):
        D = 3
        mu = self._unit_mu(D)
        dist = ComplexWatsonDistribution(mu, 5.0)
        Z = dist.sample(200)
        norms = np.abs(np.einsum("di,di->i", Z.conj(), Z))
        npt.assert_allclose(norms, np.ones(200), atol=1e-10)

    def test_sample_concentrated_near_mode(self):
        """With high kappa, |mu^H z|^2 should be close to 1."""
        D = 3
        mu = self._unit_mu(D)
        dist = ComplexWatsonDistribution(mu, 200.0)
        Z = dist.sample(100)
        inner_sq = np.abs(mu.conj() @ Z) ** 2
        self.assertGreater(inner_sq.mean(), 0.9)

    def test_fit_recovers_parameters(self):
        """fit() on many samples should approximately recover mu and kappa."""
        D = 3
        mu = self._unit_mu(D)
        kappa = 8.0
        dist = ComplexWatsonDistribution(mu, kappa)
        Z = dist.sample(1000)
        dist_hat = ComplexWatsonDistribution.fit(Z)
        # Mode mu is only identified up to a global phase rotation
        ip = abs(dist_hat.mu.conj() @ mu)
        self.assertGreater(ip, 0.99)
        self.assertAlmostEqual(dist_hat.kappa, kappa, delta=2.0)


if __name__ == "__main__":
    unittest.main()