Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions pyrecest/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,21 @@
PartiallyWrappedNormalDistribution,
)
from .cart_prod.se2_bingham_distribution import SE2BinghamDistribution
from .cart_prod.se2_pwn_distribution import SE2PWNDistribution
from .cart_prod.state_space_subdivision_distribution import (
StateSpaceSubdivisionDistribution,
)
from .cart_prod.state_space_subdivision_gaussian_distribution import (
StateSpaceSubdivisionGaussianDistribution,
)
from .cart_prod.se2_pwn_distribution import SE2PWNDistribution
from .circle.abstract_circular_distribution import AbstractCircularDistribution
from .circle.circular_dirac_distribution import CircularDiracDistribution
from .circle.circular_fourier_distribution import CircularFourierDistribution
from .circle.circular_mixture import CircularMixture
from .circle.circular_uniform_distribution import CircularUniformDistribution
from .circle.custom_circular_distribution import CustomCircularDistribution
from .circle.generalized_von_mises_distribution import GvMDistribution
from .circle.piecewise_constant_distribution import PiecewiseConstantDistribution
from .circle.sine_skewed_distributions import (
AbstractSineSkewedDistribution,
GeneralizedKSineSkewedVonMisesDistribution,
Expand All @@ -97,7 +98,6 @@
SineSkewedWrappedCauchyDistribution,
SineSkewedWrappedNormalDistribution,
)
from .circle.piecewise_constant_distribution import PiecewiseConstantDistribution
from .circle.von_mises_distribution import VonMisesDistribution
from .circle.wrapped_cauchy_distribution import WrappedCauchyDistribution
from .circle.wrapped_laplace_distribution import WrappedLaplaceDistribution
Expand All @@ -111,18 +111,6 @@
from .custom_hyperrectangular_distribution import CustomHyperrectangularDistribution
from .disk_uniform_distribution import DiskUniformDistribution
from .ellipsoidal_ball_uniform_distribution import EllipsoidalBallUniformDistribution
from .nonperiodic.abstract_hyperrectangular_distribution import (
AbstractHyperrectangularDistribution,
)
from .nonperiodic.abstract_linear_distribution import AbstractLinearDistribution
from .nonperiodic.custom_linear_distribution import CustomLinearDistribution
from .nonperiodic.gaussian_distribution import GaussianDistribution
from .nonperiodic.gaussian_mixture import GaussianMixture
from .nonperiodic.hyperrectangular_uniform_distribution import (
HyperrectangularUniformDistribution,
)
from .nonperiodic.linear_dirac_distribution import LinearDiracDistribution
from .nonperiodic.linear_mixture import LinearMixture
from .hypersphere_subset.abstract_hemispherical_distribution import (
AbstractHemisphericalDistribution,
)
Expand Down Expand Up @@ -172,15 +160,12 @@
from .hypersphere_subset.hyperhemispherical_dirac_distribution import (
HyperhemisphericalDiracDistribution,
)
from .hypersphere_subset.hyperhemispherical_uniform_distribution import (
HyperhemisphericalUniformDistribution,
)
from .hypersphere_subset.hyperspherical_uniform_distribution import (
HypersphericalUniformDistribution,
)
from .hypersphere_subset.hyperhemispherical_grid_distribution import (
HyperhemisphericalGridDistribution,
)
from .hypersphere_subset.hyperhemispherical_uniform_distribution import (
HyperhemisphericalUniformDistribution,
)
from .hypersphere_subset.hyperhemispherical_watson_distribution import (
HyperhemisphericalWatsonDistribution,
)
Expand All @@ -191,6 +176,9 @@
HypersphericalGridDistribution,
)
from .hypersphere_subset.hyperspherical_mixture import HypersphericalMixture
from .hypersphere_subset.hyperspherical_uniform_distribution import (
HypersphericalUniformDistribution,
)
from .hypersphere_subset.spherical_grid_distribution import SphericalGridDistribution
from .hypersphere_subset.spherical_harmonics_distribution_complex import (
SphericalHarmonicsDistributionComplex,
Expand Down Expand Up @@ -237,6 +225,18 @@
from .hypertorus.toroidal_wrapped_normal_distribution import (
ToroidalWrappedNormalDistribution,
)
from .nonperiodic.abstract_hyperrectangular_distribution import (
AbstractHyperrectangularDistribution,
)
from .nonperiodic.abstract_linear_distribution import AbstractLinearDistribution
from .nonperiodic.custom_linear_distribution import CustomLinearDistribution
from .nonperiodic.gaussian_distribution import GaussianDistribution
from .nonperiodic.gaussian_mixture import GaussianMixture
from .nonperiodic.hyperrectangular_uniform_distribution import (
HyperrectangularUniformDistribution,
)
from .nonperiodic.linear_dirac_distribution import LinearDiracDistribution
from .nonperiodic.linear_mixture import LinearMixture
from .se2_dirac_distribution import SE2DiracDistribution
from .se3_cart_prod_stacked_distribution import SE3CartProdStackedDistribution
from .se3_dirac_distribution import SE3DiracDistribution
Expand Down
40 changes: 31 additions & 9 deletions pyrecest/distributions/circle/piecewise_constant_distribution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
# pylint: disable=no-name-in-module,no-member,redefined-builtin
import pyrecest.backend
from pyrecest.backend import mod, pi, array, arange, mean, floor, zeros, exp, sum, log, random
from pyrecest.backend import (
arange,
array,
exp,
floor,
log,
mean,
mod,
pi,
random,
sum,
zeros,
)

from .abstract_circular_distribution import AbstractCircularDistribution

Expand Down Expand Up @@ -48,7 +60,10 @@ def pdf(self, xs):
n_intervals = len(self.w)
xs_mod = array(mod(xs, 2.0 * pi), dtype=float)
idx = array(
[min(int(floor(x / (2.0 * pi) * n_intervals)), n_intervals - 1) for x in xs_mod]
[
min(int(floor(x / (2.0 * pi) * n_intervals)), n_intervals - 1)
for x in xs_mod
]
)
return self.w[idx]

Expand All @@ -66,7 +81,9 @@ def trigonometric_moment(self, n):
n-th trigonometric moment.
"""
if pyrecest.backend.__backend_name__ == "jax": # pylint: disable=no-member
raise NotImplementedError("trigonometric_moment is not supported on the JAX backend.")
raise NotImplementedError(
"trigonometric_moment is not supported on the JAX backend."
)
if n == 0:
return 1.0 + 0j
num = len(self.w)
Expand Down Expand Up @@ -111,8 +128,13 @@ def sample(self, n):
# construction. Divide by sum anyway to guard against floating-point drift.
interval_probs = self.w * interval_width
interval_probs /= interval_probs.sum()
interval_indices = random.choice(arange(num_intervals), size=(n,), p=interval_probs)
return interval_indices * interval_width + random.uniform(size=(n,)) * interval_width
interval_indices = random.choice(
arange(num_intervals), size=(n,), p=interval_probs
)
return (
interval_indices * interval_width
+ random.uniform(size=(n,)) * interval_width
)

@staticmethod
def left_border(m, n):
Expand Down Expand Up @@ -190,14 +212,14 @@ def calculate_parameters_numerically(pdf_func, n):
from scipy.integrate import quad # pylint: disable=import-outside-toplevel

if pyrecest.backend.__backend_name__ == "jax": # pylint: disable=no-member
raise NotImplementedError("calculate_parameters_numerically is not supported on the JAX backend.")
raise NotImplementedError(
"calculate_parameters_numerically is not supported on the JAX backend."
)

assert n >= 1
w = zeros(n)
for j in range(1, n + 1):
left = PiecewiseConstantDistribution.left_border(j, n)
r = PiecewiseConstantDistribution.right_border(j, n)
w[j - 1] = quad(
lambda x: float(pdf_func(array([x]))), left, r
)[0]
w[j - 1] = quad(lambda x: float(pdf_func(array([x]))), left, r)[0]
return w
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
# pylint: disable=redefined-builtin,no-name-in-module,no-member
from scipy.integrate import quad
from scipy.optimize import fsolve
from scipy.special import iv

from pyrecest.backend import (
abs,
all,
Expand All @@ -15,12 +11,15 @@
linalg,
max,
maximum,
pi,
ones,
pi,
sort,
sum,
zeros,
sort,
)
from scipy.integrate import quad
from scipy.optimize import fsolve
from scipy.special import iv

from .abstract_hyperspherical_distribution import AbstractHypersphericalDistribution

Expand Down
10 changes: 6 additions & 4 deletions pyrecest/distributions/se2_dirac_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
)


class SE2DiracDistribution(
HypercylindricalDiracDistribution, AbstractSE2Distribution
):
class SE2DiracDistribution(HypercylindricalDiracDistribution, AbstractSE2Distribution):
"""Partially wrapped Dirac distribution on SE(2).

Represents a distribution on SE(2) = S^1 x R^2 using weighted Dirac
Expand Down Expand Up @@ -56,7 +54,11 @@ def covariance_4d(self):
-------
array of shape (4, 4)
"""
from pyrecest.backend import column_stack, cos, sin # pylint: disable=import-outside-toplevel
from pyrecest.backend import ( # pylint: disable=import-outside-toplevel
column_stack,
cos,
sin,
)

S = column_stack(
(cos(self.d[:, 0:1]), sin(self.d[:, 0:1]), self.d[:, 1:])
Expand Down
6 changes: 3 additions & 3 deletions pyrecest/filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from .abstract_multitarget_tracker import AbstractMultitargetTracker
from .abstract_nearest_neighbor_tracker import AbstractNearestNeighborTracker
from .abstract_particle_filter import AbstractParticleFilter
from .bingham_filter import BinghamFilter
from .circular_ukf import CircularUKF
from .abstract_tracker_with_logging import AbstractTrackerWithLogging
from .bingham_filter import BinghamFilter
from .circular_particle_filter import CircularParticleFilter
from .circular_ukf import CircularUKF
from .euclidean_particle_filter import EuclideanParticleFilter
from .hypercylindrical_particle_filter import HypercylindricalParticleFilter
from .global_nearest_neighbor import GlobalNearestNeighbor
from .gprhm_tracker import GPRHMTracker
from .hypercylindrical_particle_filter import HypercylindricalParticleFilter
from .hyperhemisphere_cart_prod_particle_filter import (
HyperhemisphereCartProdParticleFilter,
)
Expand Down
32 changes: 22 additions & 10 deletions pyrecest/filters/circular_ukf.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def __init__(self, alpha: float = 1e-3, beta: float = 2.0, kappa: float = 0.0):
self._alpha = alpha
self._beta = beta
self._kappa = kappa
initial_state = GaussianDistribution(
array([0.0]), array([[1.0]])
)
initial_state = GaussianDistribution(array([0.0]), array([[1.0]]))
CircularFilterMixin.__init__(self)
AbstractFilter.__init__(self, initial_state)

Expand Down Expand Up @@ -147,8 +145,16 @@ def hx(x):
return x

ukf = _make_ukf(
fx, hx, dim_z=1, x0=mu0, P0=C0, Q=Q_val, R=array([[C0]]),
alpha=self._alpha, beta=self._beta, kappa=self._kappa,
fx,
hx,
dim_z=1,
x0=mu0,
P0=C0,
Q=Q_val,
R=array([[C0]]),
alpha=self._alpha,
beta=self._beta,
kappa=self._kappa,
)
ukf.predict()

Expand Down Expand Up @@ -194,9 +200,7 @@ def update_identity(self, gauss_meas: GaussianDistribution, z):
new_C = (1.0 - K) * C

new_mu = float(mod(array([new_mu]), 2.0 * pi)[0])
self._filter_state = GaussianDistribution(
array([new_mu]), array([[new_C]])
)
self._filter_state = GaussianDistribution(array([new_mu]), array([[new_C]]))

def update_nonlinear( # pylint: disable=too-many-locals
self, f, gauss_meas: GaussianDistribution, z, measurement_periodic: bool = False
Expand Down Expand Up @@ -251,8 +255,16 @@ def hx(x):
return atleast_1d(array([f(x.flatten()[0])], dtype=float))

ukf = _make_ukf(
fx, hx, dim_z=dim_z, x0=mu0, P0=C0, Q=0.0, R=R_mat,
alpha=self._alpha, beta=self._beta, kappa=self._kappa,
fx,
hx,
dim_z=dim_z,
x0=mu0,
P0=C0,
Q=0.0,
R=R_mat,
alpha=self._alpha,
beta=self._beta,
kappa=self._kappa,
)
# predict() with identity fx and Q=0 populates sigmas_f without
# altering the mean or covariance, which is required before update().
Expand Down
4 changes: 3 additions & 1 deletion pyrecest/filters/hypercylindrical_particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from .manifold_mixins import HypercylindricalFilterMixin


class HypercylindricalParticleFilter(AbstractParticleFilter, HypercylindricalFilterMixin):
class HypercylindricalParticleFilter(
AbstractParticleFilter, HypercylindricalFilterMixin
):
def __init__(
self,
n_particles: Union[int, int32, int64],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pyrecest.backend

# pylint: disable=no-name-in-module,no-member,redefined-builtin
from pyrecest.backend import linspace, sum, exp, log, mean, pi, array
from pyrecest.backend import array, exp, linspace, log, mean, pi, sum
from pyrecest.distributions.circle.piecewise_constant_distribution import (
PiecewiseConstantDistribution,
)
Expand Down Expand Up @@ -33,9 +33,7 @@ def test_pdf(self):
def test_integral_normalized(self):
"""Verify the distribution integrates to 1 via the exact sum."""
n = len(self.dist.w)
npt.assert_allclose(
sum(self.dist.w) * (2.0 * pi / n), 1.0, rtol=5e-7
)
npt.assert_allclose(sum(self.dist.w) * (2.0 * pi / n), 1.0, rtol=5e-7)

def test_integral_partial(self):
"""Verify partial integrals sum to 1 using a fine grid."""
Expand Down
4 changes: 3 additions & 1 deletion pyrecest/tests/distributions/test_se2_dirac_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def test_sampling(self):
s = self.dist.sample(n)
self.assertEqual(s.shape, (n, 3))
# Angles should be in [0, 2*pi)
from pyrecest.backend import all as backend_all # pylint: disable=import-outside-toplevel
from pyrecest.backend import (
all as backend_all, # pylint: disable=import-outside-toplevel
)

self.assertTrue(backend_all(s[:, 0] >= 0))
self.assertTrue(backend_all(s[:, 0] < 2 * pi))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def setUp(self):
def test_initialization(self):
hpf = HypercylindricalParticleFilter(10, self.bound_dim, self.lin_dim)
self.assertIsNotNone(hpf.filter_state)
self.assertEqual(
hpf.filter_state.d.shape, (10, self.bound_dim + self.lin_dim)
)
self.assertEqual(hpf.filter_state.d.shape, (10, self.bound_dim + self.lin_dim))

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "jax", reason="Backend not supported"
Expand Down
Loading