diff --git a/pyrecest/distributions/hypersphere_subset/spherical_harmonics_distribution_complex.py b/pyrecest/distributions/hypersphere_subset/spherical_harmonics_distribution_complex.py index 03ed421a7..fc4bba5ed 100644 --- a/pyrecest/distributions/hypersphere_subset/spherical_harmonics_distribution_complex.py +++ b/pyrecest/distributions/hypersphere_subset/spherical_harmonics_distribution_complex.py @@ -1,7 +1,9 @@ +import warnings + +import pyrecest import scipy # pylint: disable=redefined-builtin,no-name-in-module,no-member -# pylint: disable=no-name-in-module,no-member from pyrecest.backend import ( abs, all, @@ -14,7 +16,9 @@ full, imag, isnan, + zeros_like, linalg, + maximum, pi, real, reshape, @@ -22,6 +26,9 @@ sin, sqrt, zeros, + cos, + meshgrid, + deg2rad, ) # pylint: disable=E0611 @@ -37,6 +44,9 @@ class SphericalHarmonicsDistributionComplex(AbstractSphericalHarmonicsDistribution): def __init__(self, coeff_mat, transformation="identity", assert_real=True): + assert ( + pyrecest.backend.__backend_name__ != "jax" # pylint: disable=no-member + ), "SphericalHarmonicsDistributionComplex is not supported on the JAX backend" AbstractSphericalHarmonicsDistribution.__init__(self, coeff_mat, transformation) self.assert_real = assert_real @@ -190,3 +200,261 @@ def imag_part(phi, theta, n, m): coeff_mat[n, m + n] = real_integral + 1j * imag_integral return SphericalHarmonicsDistributionComplex(coeff_mat, transformation) + + # ------------------------------------------------------------------ + # pyshtools-based helper methods + # ------------------------------------------------------------------ + + @staticmethod + def _coeff_mat_to_pysh(coeff_mat, degree): + """Convert our coeff_mat to a pyshtools SHComplexCoeffs object.""" + import pyshtools as pysh # pylint: disable=import-error + + clm = pysh.SHCoeffs.from_zeros( + degree, kind="complex", normalization="ortho", csphase=-1 + ) + for n in range(degree + 1): + for m in range(0, n + 1): + clm.coeffs[0, n, m] = coeff_mat[n, n + m] + for m in range(1, n + 1): + clm.coeffs[1, n, m] = coeff_mat[n, n - m] + return clm + + @staticmethod + def _pysh_to_coeff_mat(clm, degree): + """Convert a pyshtools SHComplexCoeffs object to our coeff_mat.""" + coeff_mat = zeros((degree + 1, 2 * degree + 1), dtype=complex128) + max_n = min(clm.lmax, degree) + for n in range(max_n + 1): + for m in range(0, n + 1): + coeff_mat[n, n + m] = clm.coeffs[0, n, m] + for m in range(1, n + 1): + coeff_mat[n, n - m] = clm.coeffs[1, n, m] + return coeff_mat + + @staticmethod + def _get_dh_grid_cartesian(degree): + """Return (x, y, z) flat arrays and grid_shape for the DH grid at *degree*.""" + import pyshtools as pysh # pylint: disable=import-error + + dummy = pysh.SHCoeffs.from_zeros( + degree, kind="complex", normalization="ortho", csphase=-1 + ) + grid = dummy.expand(grid="DH", extend=False) + lats, lons = grid.lats(), grid.lons() + lon_mesh, lat_mesh = meshgrid(array(lons), array(lats)) + theta = deg2rad(90.0 - lat_mesh) # colatitude in radians + phi = deg2rad(lon_mesh) # azimuth in radians + x_c = sin(theta) * cos(phi) + y_c = sin(theta) * sin(phi) + z_c = cos(theta) + return x_c.ravel(), y_c.ravel(), z_c.ravel(), theta.shape + + def _eval_on_grid(self, target_degree=None): + """Evaluate this SHD on the DH grid at *target_degree* (defaults to own degree). + + The DH grid is expanded at *target_degree* so that higher-frequency grids + can be used for intermediate computations (e.g. when squaring a degree-L + function that has degree 2L). + Returns a real 2-D numpy array of shape (nlat, nlon). + """ + import pyshtools as pysh # pylint: disable=import-error + + degree = self.coeff_mat.shape[0] - 1 + if target_degree is None: + target_degree = degree + + # Pad our coefficients into a pysh object at target_degree + clm_full = pysh.SHCoeffs.from_zeros( + target_degree, kind="complex", normalization="ortho", csphase=-1 + ) + min_deg = min(degree, target_degree) + clm_own = self._coeff_mat_to_pysh(self.coeff_mat, min_deg) + clm_full.coeffs[0, : min_deg + 1, : min_deg + 1] = clm_own.coeffs[ + 0, : min_deg + 1, : min_deg + 1 + ] + clm_full.coeffs[1, : min_deg + 1, : min_deg + 1] = clm_own.coeffs[ + 1, : min_deg + 1, : min_deg + 1 + ] + grid = clm_full.expand(grid="DH", extend=False) + return array(grid.data.real) + + @staticmethod + def _fit_from_grid(grid_vals_real, degree, transformation): + """Fit SH coefficients to real-valued grid values on a DH grid. + + *grid_vals_real* is a 2D array (numpy or backend tensor) with shape + matching the DH grid for *degree*. Returns a new + :class:`SphericalHarmonicsDistributionComplex`. + """ + import numpy as _np # noqa: PLC0415 + import pyshtools as pysh # pylint: disable=import-error + + grid_vals_np = _np.asarray(grid_vals_real) + grid_obj = pysh.SHGrid.from_array( + grid_vals_np.astype(complex), grid="DH" + ) + clm = grid_obj.expand(lmax_calc=degree, normalization="ortho", csphase=-1) + coeff_mat = SphericalHarmonicsDistributionComplex._pysh_to_coeff_mat(clm, degree) + shd = SphericalHarmonicsDistributionComplex(coeff_mat, transformation) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + shd.normalize_in_place() + return shd + + # ------------------------------------------------------------------ + # Public numerical methods + # ------------------------------------------------------------------ + + @staticmethod + def from_distribution_numerical_fast(dist, degree, transformation="identity"): + """Approximate *dist* by a degree-*degree* SHD using a DH grid. + + This is faster than :meth:`from_distribution_via_integral` because it + uses the discrete spherical-harmonic transform instead of numerical + quadrature. + """ + if transformation not in ("identity", "sqrt"): + raise ValueError(f"Unsupported transformation: '{transformation}'") + + x_c, y_c, z_c, grid_shape = ( + SphericalHarmonicsDistributionComplex._get_dh_grid_cartesian(degree) + ) + xs = column_stack([x_c, y_c, z_c]) + fvals = array(dist.pdf(xs), dtype=float).reshape(grid_shape) + + if transformation == "sqrt": + fvals = sqrt(maximum(fvals, 0.0)) + + return SphericalHarmonicsDistributionComplex._fit_from_grid( + fvals, degree, transformation + ) + + def convolve(self, other): # pylint: disable=too-many-locals + """Spherical convolution with a *zonal* distribution *other*. + + For the ``'identity'`` transformation the standard frequency-domain + formula is used (exact for bandlimited functions). For the ``'sqrt'`` + transformation a grid-based approach with a 2× finer intermediate grid + is used so that squaring the sqrt functions introduces no aliasing. + """ + assert isinstance( + other, SphericalHarmonicsDistributionComplex + ), "other must be a SphericalHarmonicsDistributionComplex" + + degree = self.coeff_mat.shape[0] - 1 + + if self.transformation == "identity" and other.transformation == "identity": + # Direct frequency-domain formula: h_{n,m} = sqrt(4π/(2n+1)) * f_{n,m} * g_{n,0} + h_lm = zeros_like(self.coeff_mat) + for n in range(degree + 1): + factor = ( + sqrt(4.0 * pi / (2 * n + 1)) + * other.coeff_mat[n, n] + ) + for m in range(-n, n + 1): + h_lm[n, n + m] = factor * self.coeff_mat[n, n + m] + result = SphericalHarmonicsDistributionComplex(h_lm, "identity") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + result.normalize_in_place() + return result + + if self.transformation == "sqrt" and other.transformation == "sqrt": + # Use a grid twice as fine to avoid aliasing when squaring. + degree_fine = 2 * degree + + # Recover p = f^2 and q = g^2 on the fine grid, then SHT to degree. + f_grid = self._eval_on_grid(target_degree=degree_fine) + p_grid = f_grid**2 + + g_grid = other._eval_on_grid(target_degree=degree_fine) # pylint: disable=protected-access + q_grid = g_grid**2 + + import numpy as _np # noqa: PLC0415 + import pyshtools as pysh # pylint: disable=import-error + + def _grid_to_coeff(grid_vals): + grid_vals_np = _np.asarray(grid_vals) + grid_obj = pysh.SHGrid.from_array( + grid_vals_np.astype(complex), grid="DH" + ) + clm = grid_obj.expand( + lmax_calc=degree, normalization="ortho", csphase=-1 + ) + return self._pysh_to_coeff_mat(clm, degree) + + p_lm = _grid_to_coeff(p_grid) + q_lm = _grid_to_coeff(q_grid) + + # Convolution formula on the identity coefficients + r_lm = zeros_like(p_lm) + for n in range(degree + 1): + factor = sqrt(4.0 * pi / (2 * n + 1)) * q_lm[n, n] + for m in range(-n, n + 1): + r_lm[n, n + m] = factor * p_lm[n, n + m] + + # Evaluate r on the standard DH grid, take sqrt, refit + r_shd_id = SphericalHarmonicsDistributionComplex(r_lm, "identity") + r_grid = r_shd_id._eval_on_grid() # pylint: disable=protected-access + sqrt_r_grid = sqrt(maximum(r_grid, 0.0)) + + return SphericalHarmonicsDistributionComplex._fit_from_grid( + sqrt_r_grid, degree, "sqrt" + ) + + raise ValueError( + "convolve: mixed transformations are not supported. " + "Both self and other must use the same transformation." + ) + + def multiply(self, other): + """Pointwise multiplication in physical space (Bayesian update step). + + Works for both ``'identity'`` and ``'sqrt'`` transformations. + For ``'sqrt'``: ``sqrt(p) * sqrt(q) = sqrt(p*q)`` which is the correct + sqrt-transformed product density. + """ + assert isinstance( + other, SphericalHarmonicsDistributionComplex + ), "other must be a SphericalHarmonicsDistributionComplex" + assert self.transformation == other.transformation, ( + "multiply: both distributions must use the same transformation" + ) + + degree = self.coeff_mat.shape[0] - 1 + + f_grid = self._eval_on_grid() + g_grid = other._eval_on_grid() # pylint: disable=protected-access + + h_grid = f_grid * g_grid + + return SphericalHarmonicsDistributionComplex._fit_from_grid( + h_grid, degree, self.transformation + ) + + def rotate(self, alpha, beta, gamma): + """Rotate the distribution by ZYZ Euler angles (in radians). + + Parameters + ---------- + alpha : float + First rotation angle around Z (radians). + beta : float + Second rotation angle around Y (radians). + gamma : float + Third rotation angle around Z (radians). + """ + degree = self.coeff_mat.shape[0] - 1 + clm = self._coeff_mat_to_pysh(self.coeff_mat, degree) + clm_rot = clm.rotate( + alpha * 180.0 / pi, + beta * 180.0 / pi, + gamma * 180.0 / pi, + degrees=True, + body=True, + ) + coeff_mat_rot = self._pysh_to_coeff_mat(clm_rot, degree) + return SphericalHarmonicsDistributionComplex( + coeff_mat_rot, self.transformation + ) diff --git a/pyrecest/filters/__init__.py b/pyrecest/filters/__init__.py index 3d1ab91d3..b74b91989 100644 --- a/pyrecest/filters/__init__.py +++ b/pyrecest/filters/__init__.py @@ -71,6 +71,8 @@ "HypertoroidalFilterMixin", "HypertoroidalParticleFilter", "KalmanFilter", + "EuclideanParticleFilter", + "VonMisesFisherFilter", "KernelSMEFilter", "LinBoundedFilterMixin", "LinBoundedParticleFilter", diff --git a/pyrecest/filters/spherical_harmonics_filter.py b/pyrecest/filters/spherical_harmonics_filter.py new file mode 100644 index 000000000..6370efaba --- /dev/null +++ b/pyrecest/filters/spherical_harmonics_filter.py @@ -0,0 +1,234 @@ +import copy +import warnings + +from pyrecest.backend import zeros, sqrt, pi, linalg, abs, stack, arccos, arctan2, clip, array, ones, maximum # pylint: disable=redefined-builtin + +from pyrecest.distributions.hypersphere_subset.spherical_harmonics_distribution_complex import ( + SphericalHarmonicsDistributionComplex, +) + +from .abstract_filter import AbstractFilter +from .manifold_mixins import HypersphericalFilterMixin + + +class SphericalHarmonicsFilter(AbstractFilter, HypersphericalFilterMixin): + """Filter on the unit sphere using spherical harmonic representations. + + Supports both the ``'identity'`` transformation (coefficients represent + the density directly) and the ``'sqrt'`` transformation (coefficients + represent the square-root of the density). + + References + ---------- + Florian Pfaff, Gerhard Kurz, and Uwe D. Hanebeck, + "Filtering on the Unit Sphere Using Spherical Harmonics", + Proceedings of the 2017 IEEE International Conference on Multisensor + Fusion and Integration for Intelligent Systems (MFI 2017), + Daegu, Korea, November 2017. + """ + + def __init__(self, degree, transformation="identity"): + HypersphericalFilterMixin.__init__(self) + coeff_mat = zeros((degree + 1, 2 * degree + 1), dtype=complex) + if transformation == "identity": + coeff_mat[0, 0] = 1.0 / sqrt(4.0 * pi) + elif transformation == "sqrt": + coeff_mat[0, 0] = 1.0 + else: + raise ValueError(f"Unknown transformation: '{transformation}'") + initial_state = SphericalHarmonicsDistributionComplex( + coeff_mat, transformation + ) + AbstractFilter.__init__(self, initial_state) + + # ------------------------------------------------------------------ + # filter_state / state property + # ------------------------------------------------------------------ + + @property + def filter_state(self): + return self._filter_state + + @filter_state.setter + def filter_state(self, new_state): + assert isinstance( + new_state, SphericalHarmonicsDistributionComplex + ), "filter_state must be a SphericalHarmonicsDistributionComplex" + self._filter_state = copy.deepcopy(new_state) + + @property + def state(self): + """Alias for :attr:`filter_state`.""" + return self._filter_state + + # ------------------------------------------------------------------ + # Public interface methods + # ------------------------------------------------------------------ + + def set_state(self, state): + """Set the filter state with optional warnings about mismatches.""" + assert isinstance( + state, SphericalHarmonicsDistributionComplex + ), "state must be a SphericalHarmonicsDistributionComplex" + if self._filter_state.transformation != state.transformation: + warnings.warn( + "setState:transDiffer: New density is transformed differently.", + stacklevel=2, + ) + if self._filter_state.coeff_mat.shape != state.coeff_mat.shape: + warnings.warn( + "setState:noOfCoeffsDiffer: New density has different number of " + "coefficients.", + stacklevel=2, + ) + self._filter_state = copy.deepcopy(state) + + def get_estimate(self): + """Return the current filter state.""" + return self._filter_state + + def get_estimate_mean(self): + """Return the mean direction of the current filter state.""" + return self._filter_state.mean_direction() + + # ------------------------------------------------------------------ + # Prediction + # ------------------------------------------------------------------ + + def predict_identity(self, sys_noise): + """Predict via spherical convolution with a *zonal* system noise SHD. + + Parameters + ---------- + sys_noise : SphericalHarmonicsDistributionComplex + Must be a zonal distribution (rotationally symmetric around the + z-axis) in the same transformation as the filter state. + """ + assert isinstance( + sys_noise, SphericalHarmonicsDistributionComplex + ), "sys_noise must be a SphericalHarmonicsDistributionComplex" + if ( + self._filter_state.transformation == "sqrt" + and sys_noise.transformation == "identity" + ): + state_degree = self._filter_state.coeff_mat.shape[0] - 1 + noise_degree = sys_noise.coeff_mat.shape[0] - 1 + assert 2 * state_degree == noise_degree, ( + "If the sqrt variant is used and sys_noise is given in " + "identity form, sys_noise should have degree 2 * state_degree." + ) + self._filter_state = self._filter_state.convolve(sys_noise) + + # ------------------------------------------------------------------ + # Update + # ------------------------------------------------------------------ + + def update_identity(self, meas_noise, z): + """Update by multiplying the state with a (possibly rotated) noise SHD. + + *meas_noise* should be a zonal SHD with its axis of symmetry along + [0, 0, 1]. If the measurement *z* differs from [0, 0, 1] the noise is + rotated to align with *z* before the multiplication. + + Parameters + ---------- + meas_noise : SphericalHarmonicsDistributionComplex + Zonal measurement noise (axis along [0, 0, 1]). + z : array-like, shape (3,) + Measurement direction on the unit sphere. + """ + assert isinstance( + meas_noise, SphericalHarmonicsDistributionComplex + ), "meas_noise must be a SphericalHarmonicsDistributionComplex" + z = array(z, dtype=float).ravel() + z_norm = linalg.norm(z) + not_near_north_pole = ( + abs(z[0]) > 1e-6 or abs(z[1]) > 1e-6 or abs(z[2] - 1.0) > 1e-6 + ) + if z_norm > 1e-6 and not_near_north_pole: + warnings.warn( + "SphericalHarmonicsFilter:rotationRequired: " + "Performance may be low for z != [0, 0, 1]. " + "Using update_nonlinear may yield faster results.", + stacklevel=2, + ) + phi = arctan2(z[1], z[0]) # azimuth + theta = arccos( + clip(z[2] / z_norm, -1.0, 1.0) + ) # colatitude + meas_noise = meas_noise.rotate(0.0, theta, phi) + self._filter_state = self._filter_state.multiply(meas_noise) + + def update_nonlinear(self, likelihood, z): + """Nonlinear Bayesian update via a likelihood function. + + Parameters + ---------- + likelihood : callable + ``likelihood(z, pts)`` where *pts* is a ``(3, N)`` matrix of + Cartesian coordinates on the unit sphere and the return value is a + length-N array of likelihood values. + z : array-like + Measurement (passed through to *likelihood* unchanged). + """ + self._update_nonlinear_impl([likelihood], [z]) + + def update_nonlinear_multiple(self, likelihoods, measurements): + """Nonlinear update using a list of likelihood functions simultaneously. + + Parameters + ---------- + likelihoods : list of callables + Each element is a likelihood function as described in + :meth:`update_nonlinear`. + measurements : list of array-like + Corresponding measurements. + """ + assert len(likelihoods) == len( + measurements + ), "likelihoods and measurements must have the same length" + self._update_nonlinear_impl(likelihoods, measurements) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _update_nonlinear_impl(self, likelihoods, measurements): # pylint: disable=too-many-locals + """Shared implementation for single and multiple nonlinear updates.""" + + degree = self._filter_state.coeff_mat.shape[0] - 1 + + # DH grid coordinates + x_c, y_c, z_c, grid_shape = ( + SphericalHarmonicsDistributionComplex._get_dh_grid_cartesian(degree) # pylint: disable=protected-access + ) + # (3, N) matrix for likelihood calls + grid_pts = stack([x_c, y_c, z_c], axis=0) + + # Evaluate current state on the DH grid + fval_curr = self._filter_state._eval_on_grid() # pylint: disable=protected-access + + # Accumulate likelihood values over all (likelihood, measurement) pairs + likelihood_vals = ones(grid_shape, dtype=float) + for lk, zk in zip(likelihoods, measurements): + lv = array(lk(zk, grid_pts), dtype=float).reshape(grid_shape) + likelihood_vals *= lv + + # Scale factor: multiplying by 2^n keeps values away from zero so that + # the SHT fit and subsequent normalisation remain numerically stable + # when the product likelihood is very small (e.g. many weak likelihoods). + # This factor is divided out implicitly by the normalisation step. + scale = float(2 ** len(likelihoods)) + + if self._filter_state.transformation == "identity": + fval_new = scale * fval_curr * likelihood_vals + elif self._filter_state.transformation == "sqrt": + fval_new = scale * fval_curr * sqrt(maximum(likelihood_vals, 0.0)) + else: + raise ValueError( + f"Unsupported transformation: '{self._filter_state.transformation}'" + ) + + self._filter_state = SphericalHarmonicsDistributionComplex._fit_from_grid( # pylint: disable=protected-access + fval_new, degree, self._filter_state.transformation + ) diff --git a/pyrecest/filters/von_mises_fisher_filter.py b/pyrecest/filters/von_mises_fisher_filter.py index 2ea742143..6e274da9a 100644 --- a/pyrecest/filters/von_mises_fisher_filter.py +++ b/pyrecest/filters/von_mises_fisher_filter.py @@ -24,6 +24,14 @@ def filter_state(self, filter_state): ), "filter_state must be an instance of VonMisesFisherDistribution." self._filter_state = filter_state + def set_state(self, state): + """Set the filter state.""" + self.filter_state = state + + def get_estimate_mean(self): + """Return the mean direction of the current filter state.""" + return self.filter_state.mean_direction() + def predict_identity(self, sys_noise): """ State prediction via mulitiplication. Provide zonal density for update diff --git a/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py b/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py index 2383b78d2..0305aabb7 100644 --- a/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py +++ b/pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py @@ -2,6 +2,7 @@ import warnings import numpy.testing as npt +import pyrecest import pyrecest.backend from parameterized import parameterized @@ -36,6 +37,10 @@ ) +@unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on the JAX backend", +) class SphericalHarmonicsDistributionComplexTest(unittest.TestCase): def setUp(self): random.seed(1) diff --git a/pyrecest/tests/filters/test_spherical_harmonics_filter.py b/pyrecest/tests/filters/test_spherical_harmonics_filter.py new file mode 100644 index 000000000..fc1598de9 --- /dev/null +++ b/pyrecest/tests/filters/test_spherical_harmonics_filter.py @@ -0,0 +1,125 @@ +import unittest +import numpy as np +import numpy.testing as npt +from scipy.stats import norm + +import pyrecest +import pyrecest.backend +from pyrecest.backend import array +from pyrecest.distributions import VonMisesFisherDistribution, SphericalHarmonicsDistributionComplex +from pyrecest.filters import VonMisesFisherFilter +from pyrecest.filters.spherical_harmonics_filter import SphericalHarmonicsFilter + +_skip_jax = unittest.skipIf( + pyrecest.backend.__backend_name__ == "jax", # pylint: disable=no-member + reason="Not supported on the JAX backend", +) + + +class SphericalHarmonicsFilterTest(unittest.TestCase): + + @_skip_jax + def test_update_identity(self): + for transformation in ['identity', 'sqrt']: + shd_filter = SphericalHarmonicsFilter(30, transformation) + vmf_filter = VonMisesFisherFilter() + + vmf1 = VonMisesFisherDistribution(array([0.0, 1.0, 0.0]), 1) + vmf2 = VonMisesFisherDistribution(array([0.0, 0.0, 1.0]), 0.1) + + shd1 = SphericalHarmonicsDistributionComplex.from_distribution_numerical_fast(vmf1, 30, transformation) + shd2 = SphericalHarmonicsDistributionComplex.from_distribution_numerical_fast(vmf2, 30, transformation) + + vmf_filter.set_state(vmf1) + vmf_filter.update_identity(vmf2, array([1.0, 0.0, 0.0])) + + shd_filter.set_state(shd1) + shd_filter.update_identity(shd2, array([1.0, 0.0, 0.0])) + + npt.assert_allclose(vmf_filter.get_estimate_mean(), shd_filter.get_estimate_mean(), atol=1e-10) + + @_skip_jax + def test_update_using_likelihood(self): + np.random.seed(1) + + pos_true = -1 / np.sqrt(3) * np.ones(3) + + # Generate measurements according to truncated gaussian along x, y, and z axis + sigma_x = 0.3 + sigma_y = 0.3 + sigma_z = 0.3 + + meas_x = self._generate_truncated_normals(pos_true[0], sigma_x, 5) + meas_y = self._generate_truncated_normals(pos_true[1], sigma_y, 5) + meas_z = self._generate_truncated_normals(pos_true[2], sigma_z, 5) + + for transformation in ['identity', 'sqrt']: + sh_filter = SphericalHarmonicsFilter(11, transformation) + + for x in meas_x: + sh_filter.update_nonlinear(lambda z, x: norm.pdf(z[0], x[0], sigma_x), array([x, 0.0, 0.0])) + for y in meas_y: + sh_filter.update_nonlinear(lambda z, x: norm.pdf(z[1], x[1], sigma_y), array([0.0, y, 0.0])) + for z in meas_z: + sh_filter.update_nonlinear(lambda z, x: norm.pdf(z[2], x[2], sigma_z), array([0.0, 0.0, z])) + + npt.assert_allclose(sh_filter.get_estimate_mean(), pos_true, atol=0.3) + + @_skip_jax + def test_update_using_likelihood_multiple(self): + sigma_x = 0.3 + sigma_y = 0.3 + sigma_z = 0.3 + for transformation in ['identity', 'sqrt']: + sh_filter1 = SphericalHarmonicsFilter(10, transformation) + sh_filter2 = SphericalHarmonicsFilter(10, transformation) + + sh_filter1.update_nonlinear(lambda z, x: norm.pdf(z[0], x[0], sigma_x), array([-1.0 / np.sqrt(3), 0.0, 0.0])) + sh_filter1.update_nonlinear(lambda z, x: norm.pdf(z[1], x[1], sigma_y), array([0.0, -1.0 / np.sqrt(3), 0.0])) + sh_filter1.update_nonlinear(lambda z, x: norm.pdf(z[2], x[2], sigma_z), array([0.0, 0.0, -1.0 / np.sqrt(3)])) + + sh_filter2.update_nonlinear_multiple( + [ + lambda z, x: norm.pdf(z[0], x[0], sigma_x), + lambda z, x: norm.pdf(z[1], x[1], sigma_y), + lambda z, x: norm.pdf(z[2], x[2], sigma_z) + ], + [ + array([-1.0 / np.sqrt(3), 0.0, 0.0]), + array([0.0, -1.0 / np.sqrt(3), 0.0]), + array([0.0, 0.0, -1.0 / np.sqrt(3)]) + ] + ) + + npt.assert_allclose(sh_filter2.get_estimate_mean(), sh_filter1.get_estimate_mean(), atol=1e-5) + + @_skip_jax + def test_prediction_sqrt_vs_id(self): + degree = 21 + density_init = VonMisesFisherDistribution(array([1.0, 1.0, 0.0]) / np.sqrt(2), 2) + sys_noise = VonMisesFisherDistribution(array([0.0, 0.0, 1.0]), 1) + + shd_init_id = SphericalHarmonicsDistributionComplex.from_distribution_numerical_fast(density_init, degree, 'identity') + shd_init_sqrt = SphericalHarmonicsDistributionComplex.from_distribution_numerical_fast(density_init, degree, 'sqrt') + shd_noise_id = SphericalHarmonicsDistributionComplex.from_distribution_numerical_fast(sys_noise, degree, 'identity') + shd_noise_sqrt = SphericalHarmonicsDistributionComplex.from_distribution_numerical_fast(sys_noise, degree, 'sqrt') + + sh_filter1 = SphericalHarmonicsFilter(degree, 'identity') + sh_filter2 = SphericalHarmonicsFilter(degree, 'sqrt') + + sh_filter1.set_state(shd_init_id) + sh_filter2.set_state(shd_init_sqrt) + + sh_filter1.predict_identity(shd_noise_id) + sh_filter2.predict_identity(shd_noise_sqrt) + + np.testing.assert_allclose(sh_filter1.state.total_variation_distance_numerical(sh_filter2.state), 0, atol=5e-15) + + # Helper function to generate truncated normals + def _generate_truncated_normals(self, mu, sigma, n): + samples = [] + while len(samples) < n: + sample = np.random.normal(mu, sigma) + if -1 <= sample <= 1: + samples.append(sample) + return samples