Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 0 additions & 45 deletions numpyro/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,15 @@
from collections.abc import Callable
from typing import (
Any,
Optional,
ParamSpec,
Protocol,
TypeAlias,
TypeVar,
Union,
runtime_checkable,
)
import weakref

import numpy as np

import jax
from jax.typing import ArrayLike

P = ParamSpec("P")
ModelT: TypeAlias = Callable[P, Any]
Expand All @@ -41,43 +36,3 @@


NumLikeT = TypeVar("NumLikeT", bound=NumLike)


@runtime_checkable
class ConstraintT(Protocol):
"""A protocol for typing constraints."""

@property
def is_discrete(self) -> bool: ...
@property
def event_dim(self) -> int: ...

def __call__(self, x: NumLike) -> ArrayLike: ...
def __repr__(self) -> str: ...
def check(self, value: NumLike) -> ArrayLike: ...
def feasible_like(self, prototype: NumLike) -> NumLike: ...


@runtime_checkable
class TransformT(Protocol):
_inv: Optional[Union["TransformT", weakref.ref]] = ...

@property
def domain(self) -> ConstraintT: ...
@property
def codomain(self) -> ConstraintT: ...
@property
def inv(self) -> "TransformT": ...
@property
def sign(self) -> NumLike: ...

def __call__(self, x: NumLike) -> NumLike: ...
def _inverse(self, y: NumLike) -> NumLike: ...
def log_abs_det_jacobian(
self, x: NumLike, y: NumLike, intermediates: Optional[PyTree] = None
) -> NumLike: ...
def call_with_intermediates(
self, x: NumLike
) -> tuple[NumLike, Optional[PyTree]]: ...
def forward_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
def inverse_shape(self, shape: tuple[int, ...]) -> tuple[int, ...]: ...
8 changes: 4 additions & 4 deletions numpyro/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import jax.numpy as jnp
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import log1mexp, promote_shapes, validate_sample
from numpyro.util import find_stack_level, not_jax_tracer
Expand Down Expand Up @@ -116,7 +116,7 @@ def sample(
return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

@validate_sample
Expand Down Expand Up @@ -232,7 +232,7 @@ def sample(
return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

@validate_sample
Expand Down Expand Up @@ -367,7 +367,7 @@ def sample(
return self.base_dist.expand(self.batch_shape).sample(key, sample_shape)

@constraints.dependent_property(is_discrete=False, event_dim=1)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

def _get_censoring_masks(self, value):
Expand Down
6 changes: 3 additions & 3 deletions numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from jax.scipy.special import betainc, betaln, gammaln
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import Beta, Dirichlet, Gamma
from numpyro.distributions.discrete import (
BinomialProbs,
Expand Down Expand Up @@ -105,7 +105,7 @@ def variance(self) -> ArrayLike:
)

@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return constraints.integer_interval(0, self.total_count)


Expand Down Expand Up @@ -324,7 +324,7 @@ def variance(self) -> ArrayLike:
return n * alpha_ratio * (1 - alpha_ratio) * (n + alpha_sum) / (1 + alpha_sum)

@constraints.dependent_property(is_discrete=True, event_dim=1)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return constraints.multinomial(self.total_count)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/copula.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from jax import Array, lax, numpy as jnp
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT
import numpyro.distributions.constraints as constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import Beta, MultivariateNormal, Normal
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import clamp_probs, lazy_property, validate_sample
Expand Down Expand Up @@ -100,7 +100,7 @@ def variance(self) -> ArrayLike:
return jnp.broadcast_to(self.marginal_dist.variance, self.shape())

@constraints.dependent_property(is_discrete=False, event_dim=1)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return constraints.independent(self.marginal_dist.support, 1)

@lazy_property
Expand Down
5 changes: 2 additions & 3 deletions numpyro/distributions/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import jax.numpy as jnp
from jax.typing import ArrayLike

from numpyro._typing import TransformT
from numpyro.distributions.constraints import real_vector
from numpyro.distributions.transforms import Transform
from numpyro.util import fori_loop
Expand Down Expand Up @@ -109,7 +108,7 @@ def tree_flatten(self):
{"arn": self.arn},
)

def __eq__(self, other: TransformT) -> bool:
def __eq__(self, other: Transform) -> bool:
if not isinstance(other, InverseAutoregressiveTransform):
return False
return (
Expand Down Expand Up @@ -170,7 +169,7 @@ def log_abs_det_jacobian(
def tree_flatten(self):
return (), ((), {"bn_arn": self.bn_arn})

def __eq__(self, other: TransformT) -> bool:
def __eq__(self, other: Transform) -> bool:
return (
isinstance(other, BlockNeuralAutoregressiveTransform)
and self.bn_arn is other.bn_arn
Expand Down
8 changes: 4 additions & 4 deletions numpyro/distributions/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import jax.numpy as jnp
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.discrete import CategoricalLogits, CategoricalProbs
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import validate_sample
Expand Down Expand Up @@ -258,7 +258,7 @@ def component_distribution(self) -> Distribution:
return self._component_distribution

@constraints.dependent_property
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self.component_distribution.support

@property
Expand Down Expand Up @@ -353,7 +353,7 @@ def __init__(
mixing_distribution: Union[CategoricalProbs, CategoricalLogits],
component_distributions: list[Distribution],
*,
support: Optional[ConstraintT] = None,
support: Optional[Constraint] = None,
validate_args: Optional[bool] = None,
):
_check_mixing_distribution(mixing_distribution)
Expand Down Expand Up @@ -424,7 +424,7 @@ def component_distributions(self) -> list[Distribution]:
return self._component_distributions

@constraints.dependent_property
def support(self) -> ConstraintT:
def support(self) -> Constraint:
if self._support is not None:
return self._support
return self.component_distributions[0].support
Expand Down
12 changes: 6 additions & 6 deletions numpyro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT
from numpyro.distributions import constraints
from numpyro.distributions.constraints import Constraint
from numpyro.distributions.continuous import (
Cauchy,
Laplace,
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
super().__init__(batch_shape, validate_args=validate_args)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

@lazy_property
Expand Down Expand Up @@ -162,7 +162,7 @@ def __init__(
super().__init__(batch_shape, validate_args=validate_args)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

@lazy_property
Expand Down Expand Up @@ -259,7 +259,7 @@ def __init__(
super().__init__(batch_shape, validate_args=validate_args)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

@lazy_property
Expand Down Expand Up @@ -529,7 +529,7 @@ def __init__(
)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

@validate_sample
Expand Down Expand Up @@ -1010,7 +1010,7 @@ def __init__(
)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self) -> ConstraintT:
def support(self) -> Constraint:
return self._support

@validate_sample
Expand Down