diff --git a/pyrecest/distributions/__init__.py b/pyrecest/distributions/__init__.py index 818450ca5..14475b091 100644 --- a/pyrecest/distributions/__init__.py +++ b/pyrecest/distributions/__init__.py @@ -74,13 +74,13 @@ PartiallyWrappedNormalDistribution, ) from .cart_prod.se2_bingham_distribution import SE2BinghamDistribution +from .cart_prod.se2_pwn_distribution import SE2PWNDistribution from .cart_prod.state_space_subdivision_distribution import ( StateSpaceSubdivisionDistribution, ) from .cart_prod.state_space_subdivision_gaussian_distribution import ( StateSpaceSubdivisionGaussianDistribution, ) -from .cart_prod.se2_pwn_distribution import SE2PWNDistribution from .circle.abstract_circular_distribution import AbstractCircularDistribution from .circle.circular_dirac_distribution import CircularDiracDistribution from .circle.circular_fourier_distribution import CircularFourierDistribution @@ -88,6 +88,7 @@ from .circle.circular_uniform_distribution import CircularUniformDistribution from .circle.custom_circular_distribution import CustomCircularDistribution from .circle.generalized_von_mises_distribution import GvMDistribution +from .circle.piecewise_constant_distribution import PiecewiseConstantDistribution from .circle.sine_skewed_distributions import ( AbstractSineSkewedDistribution, GeneralizedKSineSkewedVonMisesDistribution, @@ -97,7 +98,6 @@ SineSkewedWrappedCauchyDistribution, SineSkewedWrappedNormalDistribution, ) -from .circle.piecewise_constant_distribution import PiecewiseConstantDistribution from .circle.von_mises_distribution import VonMisesDistribution from .circle.wrapped_cauchy_distribution import WrappedCauchyDistribution from .circle.wrapped_laplace_distribution import WrappedLaplaceDistribution @@ -111,18 +111,6 @@ from .custom_hyperrectangular_distribution import CustomHyperrectangularDistribution from .disk_uniform_distribution import DiskUniformDistribution from .ellipsoidal_ball_uniform_distribution import EllipsoidalBallUniformDistribution -from .nonperiodic.abstract_hyperrectangular_distribution import ( - AbstractHyperrectangularDistribution, -) -from .nonperiodic.abstract_linear_distribution import AbstractLinearDistribution -from .nonperiodic.custom_linear_distribution import CustomLinearDistribution -from .nonperiodic.gaussian_distribution import GaussianDistribution -from .nonperiodic.gaussian_mixture import GaussianMixture -from .nonperiodic.hyperrectangular_uniform_distribution import ( - HyperrectangularUniformDistribution, -) -from .nonperiodic.linear_dirac_distribution import LinearDiracDistribution -from .nonperiodic.linear_mixture import LinearMixture from .hypersphere_subset.abstract_hemispherical_distribution import ( AbstractHemisphericalDistribution, ) @@ -172,15 +160,12 @@ from .hypersphere_subset.hyperhemispherical_dirac_distribution import ( HyperhemisphericalDiracDistribution, ) -from .hypersphere_subset.hyperhemispherical_uniform_distribution import ( - HyperhemisphericalUniformDistribution, -) -from .hypersphere_subset.hyperspherical_uniform_distribution import ( - HypersphericalUniformDistribution, -) from .hypersphere_subset.hyperhemispherical_grid_distribution import ( HyperhemisphericalGridDistribution, ) +from .hypersphere_subset.hyperhemispherical_uniform_distribution import ( + HyperhemisphericalUniformDistribution, +) from .hypersphere_subset.hyperhemispherical_watson_distribution import ( HyperhemisphericalWatsonDistribution, ) @@ -191,6 +176,9 @@ HypersphericalGridDistribution, ) from .hypersphere_subset.hyperspherical_mixture import HypersphericalMixture +from .hypersphere_subset.hyperspherical_uniform_distribution import ( + HypersphericalUniformDistribution, +) from .hypersphere_subset.spherical_grid_distribution import SphericalGridDistribution from .hypersphere_subset.spherical_harmonics_distribution_complex import ( SphericalHarmonicsDistributionComplex, @@ -237,6 +225,18 @@ from .hypertorus.toroidal_wrapped_normal_distribution import ( ToroidalWrappedNormalDistribution, ) +from .nonperiodic.abstract_hyperrectangular_distribution import ( + AbstractHyperrectangularDistribution, +) +from .nonperiodic.abstract_linear_distribution import AbstractLinearDistribution +from .nonperiodic.custom_linear_distribution import CustomLinearDistribution +from .nonperiodic.gaussian_distribution import GaussianDistribution +from .nonperiodic.gaussian_mixture import GaussianMixture +from .nonperiodic.hyperrectangular_uniform_distribution import ( + HyperrectangularUniformDistribution, +) +from .nonperiodic.linear_dirac_distribution import LinearDiracDistribution +from .nonperiodic.linear_mixture import LinearMixture from .se2_dirac_distribution import SE2DiracDistribution from .se3_cart_prod_stacked_distribution import SE3CartProdStackedDistribution from .se3_dirac_distribution import SE3DiracDistribution diff --git a/pyrecest/distributions/circle/piecewise_constant_distribution.py b/pyrecest/distributions/circle/piecewise_constant_distribution.py index 52da2bd10..7e78a8eb8 100644 --- a/pyrecest/distributions/circle/piecewise_constant_distribution.py +++ b/pyrecest/distributions/circle/piecewise_constant_distribution.py @@ -1,6 +1,18 @@ # pylint: disable=no-name-in-module,no-member,redefined-builtin import pyrecest.backend -from pyrecest.backend import mod, pi, array, arange, mean, floor, zeros, exp, sum, log, random +from pyrecest.backend import ( + arange, + array, + exp, + floor, + log, + mean, + mod, + pi, + random, + sum, + zeros, +) from .abstract_circular_distribution import AbstractCircularDistribution @@ -48,7 +60,10 @@ def pdf(self, xs): n_intervals = len(self.w) xs_mod = array(mod(xs, 2.0 * pi), dtype=float) idx = array( - [min(int(floor(x / (2.0 * pi) * n_intervals)), n_intervals - 1) for x in xs_mod] + [ + min(int(floor(x / (2.0 * pi) * n_intervals)), n_intervals - 1) + for x in xs_mod + ] ) return self.w[idx] @@ -66,7 +81,9 @@ def trigonometric_moment(self, n): n-th trigonometric moment. """ if pyrecest.backend.__backend_name__ == "jax": # pylint: disable=no-member - raise NotImplementedError("trigonometric_moment is not supported on the JAX backend.") + raise NotImplementedError( + "trigonometric_moment is not supported on the JAX backend." + ) if n == 0: return 1.0 + 0j num = len(self.w) @@ -111,8 +128,13 @@ def sample(self, n): # construction. Divide by sum anyway to guard against floating-point drift. interval_probs = self.w * interval_width interval_probs /= interval_probs.sum() - interval_indices = random.choice(arange(num_intervals), size=(n,), p=interval_probs) - return interval_indices * interval_width + random.uniform(size=(n,)) * interval_width + interval_indices = random.choice( + arange(num_intervals), size=(n,), p=interval_probs + ) + return ( + interval_indices * interval_width + + random.uniform(size=(n,)) * interval_width + ) @staticmethod def left_border(m, n): @@ -190,14 +212,14 @@ def calculate_parameters_numerically(pdf_func, n): from scipy.integrate import quad # pylint: disable=import-outside-toplevel if pyrecest.backend.__backend_name__ == "jax": # pylint: disable=no-member - raise NotImplementedError("calculate_parameters_numerically is not supported on the JAX backend.") + raise NotImplementedError( + "calculate_parameters_numerically is not supported on the JAX backend." + ) assert n >= 1 w = zeros(n) for j in range(1, n + 1): left = PiecewiseConstantDistribution.left_border(j, n) r = PiecewiseConstantDistribution.right_border(j, n) - w[j - 1] = quad( - lambda x: float(pdf_func(array([x]))), left, r - )[0] + w[j - 1] = quad(lambda x: float(pdf_func(array([x]))), left, r)[0] return w diff --git a/pyrecest/distributions/hypersphere_subset/bingham_distribution.py b/pyrecest/distributions/hypersphere_subset/bingham_distribution.py index cca2eb6fd..26c73e781 100644 --- a/pyrecest/distributions/hypersphere_subset/bingham_distribution.py +++ b/pyrecest/distributions/hypersphere_subset/bingham_distribution.py @@ -1,8 +1,4 @@ # pylint: disable=redefined-builtin,no-name-in-module,no-member -from scipy.integrate import quad -from scipy.optimize import fsolve -from scipy.special import iv - from pyrecest.backend import ( abs, all, @@ -15,12 +11,15 @@ linalg, max, maximum, - pi, ones, + pi, + sort, sum, zeros, - sort, ) +from scipy.integrate import quad +from scipy.optimize import fsolve +from scipy.special import iv from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution diff --git a/pyrecest/distributions/se2_dirac_distribution.py b/pyrecest/distributions/se2_dirac_distribution.py index dde56a47e..e4fa55b14 100644 --- a/pyrecest/distributions/se2_dirac_distribution.py +++ b/pyrecest/distributions/se2_dirac_distribution.py @@ -7,9 +7,7 @@ ) -class SE2DiracDistribution( - HypercylindricalDiracDistribution, AbstractSE2Distribution -): +class SE2DiracDistribution(HypercylindricalDiracDistribution, AbstractSE2Distribution): """Partially wrapped Dirac distribution on SE(2). Represents a distribution on SE(2) = S^1 x R^2 using weighted Dirac @@ -56,7 +54,11 @@ def covariance_4d(self): ------- array of shape (4, 4) """ - from pyrecest.backend import column_stack, cos, sin # pylint: disable=import-outside-toplevel + from pyrecest.backend import ( # pylint: disable=import-outside-toplevel + column_stack, + cos, + sin, + ) S = column_stack( (cos(self.d[:, 0:1]), sin(self.d[:, 0:1]), self.d[:, 1:]) diff --git a/pyrecest/filters/__init__.py b/pyrecest/filters/__init__.py index 5549ca524..9323c35ff 100644 --- a/pyrecest/filters/__init__.py +++ b/pyrecest/filters/__init__.py @@ -6,14 +6,14 @@ from .abstract_multitarget_tracker import AbstractMultitargetTracker from .abstract_nearest_neighbor_tracker import AbstractNearestNeighborTracker from .abstract_particle_filter import AbstractParticleFilter -from .bingham_filter import BinghamFilter -from .circular_ukf import CircularUKF from .abstract_tracker_with_logging import AbstractTrackerWithLogging +from .bingham_filter import BinghamFilter from .circular_particle_filter import CircularParticleFilter +from .circular_ukf import CircularUKF from .euclidean_particle_filter import EuclideanParticleFilter -from .hypercylindrical_particle_filter import HypercylindricalParticleFilter from .global_nearest_neighbor import GlobalNearestNeighbor from .gprhm_tracker import GPRHMTracker +from .hypercylindrical_particle_filter import HypercylindricalParticleFilter from .hyperhemisphere_cart_prod_particle_filter import ( HyperhemisphereCartProdParticleFilter, ) diff --git a/pyrecest/filters/circular_ukf.py b/pyrecest/filters/circular_ukf.py index 9458d0dd8..4668c5fc4 100644 --- a/pyrecest/filters/circular_ukf.py +++ b/pyrecest/filters/circular_ukf.py @@ -71,9 +71,7 @@ def __init__(self, alpha: float = 1e-3, beta: float = 2.0, kappa: float = 0.0): self._alpha = alpha self._beta = beta self._kappa = kappa - initial_state = GaussianDistribution( - array([0.0]), array([[1.0]]) - ) + initial_state = GaussianDistribution(array([0.0]), array([[1.0]])) CircularFilterMixin.__init__(self) AbstractFilter.__init__(self, initial_state) @@ -147,8 +145,16 @@ def hx(x): return x ukf = _make_ukf( - fx, hx, dim_z=1, x0=mu0, P0=C0, Q=Q_val, R=array([[C0]]), - alpha=self._alpha, beta=self._beta, kappa=self._kappa, + fx, + hx, + dim_z=1, + x0=mu0, + P0=C0, + Q=Q_val, + R=array([[C0]]), + alpha=self._alpha, + beta=self._beta, + kappa=self._kappa, ) ukf.predict() @@ -194,9 +200,7 @@ def update_identity(self, gauss_meas: GaussianDistribution, z): new_C = (1.0 - K) * C new_mu = float(mod(array([new_mu]), 2.0 * pi)[0]) - self._filter_state = GaussianDistribution( - array([new_mu]), array([[new_C]]) - ) + self._filter_state = GaussianDistribution(array([new_mu]), array([[new_C]])) def update_nonlinear( # pylint: disable=too-many-locals self, f, gauss_meas: GaussianDistribution, z, measurement_periodic: bool = False @@ -251,8 +255,16 @@ def hx(x): return atleast_1d(array([f(x.flatten()[0])], dtype=float)) ukf = _make_ukf( - fx, hx, dim_z=dim_z, x0=mu0, P0=C0, Q=0.0, R=R_mat, - alpha=self._alpha, beta=self._beta, kappa=self._kappa, + fx, + hx, + dim_z=dim_z, + x0=mu0, + P0=C0, + Q=0.0, + R=R_mat, + alpha=self._alpha, + beta=self._beta, + kappa=self._kappa, ) # predict() with identity fx and Q=0 populates sigmas_f without # altering the mean or covariance, which is required before update(). diff --git a/pyrecest/filters/hypercylindrical_particle_filter.py b/pyrecest/filters/hypercylindrical_particle_filter.py index 64065e4cf..b58fb83d2 100644 --- a/pyrecest/filters/hypercylindrical_particle_filter.py +++ b/pyrecest/filters/hypercylindrical_particle_filter.py @@ -11,7 +11,9 @@ from .manifold_mixins import HypercylindricalFilterMixin -class HypercylindricalParticleFilter(AbstractParticleFilter, HypercylindricalFilterMixin): +class HypercylindricalParticleFilter( + AbstractParticleFilter, HypercylindricalFilterMixin +): def __init__( self, n_particles: Union[int, int32, int64], diff --git a/pyrecest/tests/distributions/test_piecewise_constant_distribution.py b/pyrecest/tests/distributions/test_piecewise_constant_distribution.py index 76971a1cf..f1740906b 100644 --- a/pyrecest/tests/distributions/test_piecewise_constant_distribution.py +++ b/pyrecest/tests/distributions/test_piecewise_constant_distribution.py @@ -4,7 +4,7 @@ import pyrecest.backend # pylint: disable=no-name-in-module,no-member,redefined-builtin -from pyrecest.backend import linspace, sum, exp, log, mean, pi, array +from pyrecest.backend import array, exp, linspace, log, mean, pi, sum from pyrecest.distributions.circle.piecewise_constant_distribution import ( PiecewiseConstantDistribution, ) @@ -33,9 +33,7 @@ def test_pdf(self): def test_integral_normalized(self): """Verify the distribution integrates to 1 via the exact sum.""" n = len(self.dist.w) - npt.assert_allclose( - sum(self.dist.w) * (2.0 * pi / n), 1.0, rtol=5e-7 - ) + npt.assert_allclose(sum(self.dist.w) * (2.0 * pi / n), 1.0, rtol=5e-7) def test_integral_partial(self): """Verify partial integrals sum to 1 using a fine grid.""" diff --git a/pyrecest/tests/distributions/test_se2_dirac_distribution.py b/pyrecest/tests/distributions/test_se2_dirac_distribution.py index 5954ee1e5..7ba485f73 100644 --- a/pyrecest/tests/distributions/test_se2_dirac_distribution.py +++ b/pyrecest/tests/distributions/test_se2_dirac_distribution.py @@ -112,7 +112,9 @@ def test_sampling(self): s = self.dist.sample(n) self.assertEqual(s.shape, (n, 3)) # Angles should be in [0, 2*pi) - from pyrecest.backend import all as backend_all # pylint: disable=import-outside-toplevel + from pyrecest.backend import ( + all as backend_all, # pylint: disable=import-outside-toplevel + ) self.assertTrue(backend_all(s[:, 0] >= 0)) self.assertTrue(backend_all(s[:, 0] < 2 * pi)) diff --git a/pyrecest/tests/filters/test_hypercylindrical_particle_filter.py b/pyrecest/tests/filters/test_hypercylindrical_particle_filter.py index d54b475e4..dab640023 100644 --- a/pyrecest/tests/filters/test_hypercylindrical_particle_filter.py +++ b/pyrecest/tests/filters/test_hypercylindrical_particle_filter.py @@ -29,9 +29,7 @@ def setUp(self): def test_initialization(self): hpf = HypercylindricalParticleFilter(10, self.bound_dim, self.lin_dim) self.assertIsNotNone(hpf.filter_state) - self.assertEqual( - hpf.filter_state.d.shape, (10, self.bound_dim + self.lin_dim) - ) + self.assertEqual(hpf.filter_state.d.shape, (10, self.bound_dim + self.lin_dim)) @unittest.skipIf( pyrecest.backend.__backend_name__ == "jax", reason="Backend not supported"