diff --git a/numpyro/_typing.py b/numpyro/_typing.py index b717ef616..d53d60c17 100644 --- a/numpyro/_typing.py +++ b/numpyro/_typing.py @@ -6,20 +6,15 @@ from collections.abc import Callable from typing import ( Any, - Optional, ParamSpec, - Protocol, TypeAlias, TypeVar, Union, - runtime_checkable, ) -import weakref import numpy as np import jax -from jax.typing import ArrayLike P = ParamSpec("P") ModelT: TypeAlias = Callable[P, Any] @@ -41,43 +36,3 @@ NumLikeT = TypeVar("NumLikeT", bound=NumLike) - - -@runtime_checkable -class ConstraintT(Protocol): - """A protocol for typing constraints.""" - - @property - def is_discrete(self) -> bool: ... - @property - def event_dim(self) -> int: ... - - def __call__(self, x: NumLike) -> ArrayLike: ... - def __repr__(self) -> str: ... - def check(self, value: NumLike) -> ArrayLike: ... - def feasible_like(self, prototype: NumLike) -> NumLike: ... - - -@runtime_checkable -class TransformT(Protocol): - _inv: Optional[Union["TransformT", weakref.ref]] = ... - - @property - def domain(self) -> ConstraintT: ... - @property - def codomain(self) -> ConstraintT: ... - @property - def inv(self) -> "TransformT": ... - @property - def sign(self) -> NumLike: ... - - def __call__(self, x: NumLike) -> NumLike: ... - def _inverse(self, y: NumLike) -> NumLike: ... - def log_abs_det_jacobian( - self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None - ) -> NumLike: ... - def call_with_intermediates( - self, x: NumLike - ) -> tuple[NumLike, Optional[PyTree]]: ... - def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... - def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ... diff --git a/numpyro/distributions/censored.py b/numpyro/distributions/censored.py index ff313c576..81e18bd37 100644 --- a/numpyro/distributions/censored.py +++ b/numpyro/distributions/censored.py @@ -12,8 +12,8 @@ import jax.numpy as jnp from jax.typing import ArrayLike -from numpyro._typing import ConstraintT from numpyro.distributions import constraints +from numpyro.distributions.constraints import Constraint from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample from numpyro.util import find_stack_level, not_jax_tracer @@ -116,7 +116,7 @@ def sample( return self.base_dist.expand(self.batch_shape).sample(key, sample_shape) @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support @validate_sample @@ -232,7 +232,7 @@ def sample( return self.base_dist.expand(self.batch_shape).sample(key, sample_shape) @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support @validate_sample @@ -367,7 +367,7 @@ def sample( return self.base_dist.expand(self.batch_shape).sample(key, sample_shape) @constraints.dependent_property(is_discrete=False, event_dim=1) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support def _get_censoring_masks(self, value): diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index fca86e000..23447b99e 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -10,8 +10,8 @@ from jax.scipy.special import betainc, betaln, gammaln from jax.typing import ArrayLike -from numpyro._typing import ConstraintT from numpyro.distributions import constraints +from numpyro.distributions.constraints import Constraint from numpyro.distributions.continuous import Beta, Dirichlet, Gamma from numpyro.distributions.discrete import ( BinomialProbs, @@ -105,7 +105,7 @@ def variance(self) -> ArrayLike: ) @constraints.dependent_property(is_discrete=True, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return constraints.integer_interval(0, self.total_count) @@ -324,7 +324,7 @@ def variance(self) -> ArrayLike: return n * alpha_ratio * (1 - alpha_ratio) * (n + alpha_sum) / (1 + alpha_sum) @constraints.dependent_property(is_discrete=True, event_dim=1) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return constraints.multinomial(self.total_count) @staticmethod diff --git a/numpyro/distributions/copula.py b/numpyro/distributions/copula.py index b3aad194c..278131c04 100644 --- a/numpyro/distributions/copula.py +++ b/numpyro/distributions/copula.py @@ -8,8 +8,8 @@ from jax import Array, lax, numpy as jnp from jax.typing import ArrayLike -from numpyro._typing import ConstraintT import numpyro.distributions.constraints as constraints +from numpyro.distributions.constraints import Constraint from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample @@ -100,7 +100,7 @@ def variance(self) -> ArrayLike: return jnp.broadcast_to(self.marginal_dist.variance, self.shape()) @constraints.dependent_property(is_discrete=False, event_dim=1) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return constraints.independent(self.marginal_dist.support, 1) @lazy_property diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py index 62eef556a..8294f68c5 100644 --- a/numpyro/distributions/flows.py +++ b/numpyro/distributions/flows.py @@ -6,7 +6,6 @@ import jax.numpy as jnp from jax.typing import ArrayLike -from numpyro._typing import TransformT from numpyro.distributions.constraints import real_vector from numpyro.distributions.transforms import Transform from numpyro.util import fori_loop @@ -109,7 +108,7 @@ def tree_flatten(self): {"arn": self.arn}, ) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: Transform) -> bool: if not isinstance(other, InverseAutoregressiveTransform): return False return ( @@ -170,7 +169,7 @@ def log_abs_det_jacobian( def tree_flatten(self): return (), ((), {"bn_arn": self.bn_arn}) - def __eq__(self, other: TransformT) -> bool: + def __eq__(self, other: Transform) -> bool: return ( isinstance(other, BlockNeuralAutoregressiveTransform) and self.bn_arn is other.bn_arn diff --git a/numpyro/distributions/mixtures.py b/numpyro/distributions/mixtures.py index de55f3088..3caa828b4 100644 --- a/numpyro/distributions/mixtures.py +++ b/numpyro/distributions/mixtures.py @@ -9,8 +9,8 @@ import jax.numpy as jnp from jax.typing import ArrayLike -from numpyro._typing import ConstraintT from numpyro.distributions import constraints +from numpyro.distributions.constraints import Constraint from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import validate_sample @@ -258,7 +258,7 @@ def component_distribution(self) -> Distribution: return self._component_distribution @constraints.dependent_property - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self.component_distribution.support @property @@ -353,7 +353,7 @@ def __init__( mixing_distribution: Union[CategoricalProbs, CategoricalLogits], component_distributions: list[Distribution], *, - support: Optional[ConstraintT] = None, + support: Optional[Constraint] = None, validate_args: Optional[bool] = None, ): _check_mixing_distribution(mixing_distribution) @@ -424,7 +424,7 @@ def component_distributions(self) -> list[Distribution]: return self._component_distributions @constraints.dependent_property - def support(self) -> ConstraintT: + def support(self) -> Constraint: if self._support is not None: return self._support return self.component_distributions[0].support diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index b0e133e39..630be0b00 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -11,8 +11,8 @@ from jax.scipy.special import logsumexp from jax.typing import ArrayLike -from numpyro._typing import ConstraintT from numpyro.distributions import constraints +from numpyro.distributions.constraints import Constraint from numpyro.distributions.continuous import ( Cauchy, Laplace, @@ -57,7 +57,7 @@ def __init__( super().__init__(batch_shape, validate_args=validate_args) @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support @lazy_property @@ -162,7 +162,7 @@ def __init__( super().__init__(batch_shape, validate_args=validate_args) @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support @lazy_property @@ -259,7 +259,7 @@ def __init__( super().__init__(batch_shape, validate_args=validate_args) @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support @lazy_property @@ -529,7 +529,7 @@ def __init__( ) @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support @validate_sample @@ -1010,7 +1010,7 @@ def __init__( ) @constraints.dependent_property(is_discrete=False, event_dim=0) - def support(self) -> ConstraintT: + def support(self) -> Constraint: return self._support @validate_sample