Skip to content

Commit a094ffe

Browse files
committed
STY: searchsorted: fix typing issues
1 parent 9c811a2 commit a094ffe

File tree

2 files changed

+31
-13
lines changed

2 files changed

+31
-13
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def searchsorted(
677677
/,
678678
*,
679679
side: Literal["left", "right"] = "left",
680-
xp: ModuleType,
680+
xp: ModuleType | None = None,
681681
) -> Array:
682682
"""
683683
Find indices where elements should be inserted to maintain order.
@@ -748,7 +748,7 @@ def searchsorted(
748748

749749
# while xp.any(b - a > 1):
750750
# refactored to for loop with ~log2(n) iterations for JAX JIT
751-
for _ in range(int(math.log2(x1.shape[-1])) + 1): # type: ignore[arg-type]
751+
for _ in range(int(math.log2(x1.shape[-1])) + 1): # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
752752
c = (a + b) // 2
753753
x0 = xp.take_along_axis(x1, c, axis=-1)
754754
j = compare(x2, x0)

tests/test_funcs.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import warnings
33
from types import ModuleType
4-
from typing import Any, cast
4+
from typing import Any, Literal, cast
55

66
import hypothesis
77
import hypothesis.extra.numpy as npst
@@ -1645,10 +1645,12 @@ def test_kind(self, xp: ModuleType, library: Backend):
16451645
xp_assert_equal(res, expected)
16461646

16471647

1648-
def _apply_over_batch(*argdefs: tuple[str, int]):
1648+
def _apply_over_batch(*argdefs: tuple[str, int]) -> Any:
16491649
"""
16501650
Factory for decorator that applies a function over batched arguments.
16511651
1652+
Copied (with light simplifications) from `scipy._lib._util`.
1653+
16521654
Array arguments may have any number of core dimensions (typically 0,
16531655
1, or 2) and any broadcastable batch shapes. There may be any
16541656
number of array outputs of any number of dimensions. Assumptions
@@ -1675,8 +1677,11 @@ def _apply_over_batch(*argdefs: tuple[str, int]):
16751677
names, ndims = list(zip(*argdefs, strict=True))
16761678
n_arrays = len(names)
16771679

1678-
def decorator(f):
1679-
def wrapper(*args_tuple, **kwargs):
1680+
def decorator(f: Any) -> Any:
1681+
def wrapper(
1682+
*args_tuple: tuple[Any] | None,
1683+
**kwargs: dict[str, Any] | None,
1684+
) -> Any:
16801685
args = list(args_tuple)
16811686

16821687
# Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs`
@@ -1688,9 +1693,9 @@ def wrapper(*args_tuple, **kwargs):
16881693
f"{f.__name__}() got multiple values for argument `{name}`."
16891694
)
16901695
raise ValueError(message)
1691-
arrays.append(kwargs.pop(name))
1696+
arrays.append(kwargs.pop(name)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
16921697

1693-
xp = array_namespace(*arrays)
1698+
xp = array_namespace(*arrays) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
16941699

16951700
# Determine core and batch shapes
16961701
batch_shapes = []
@@ -1751,8 +1756,13 @@ def wrapper(*args_tuple, **kwargs):
17511756
return decorator
17521757

17531758

1754-
@_apply_over_batch(("a", 1), ("v", 1))
1755-
def xp_searchsorted(a, v, side, xp):
1759+
@_apply_over_batch(("a", 1), ("v", 1)) # type: ignore[misc]
1760+
def xp_searchsorted(
1761+
a: Array,
1762+
v: Array,
1763+
side: Literal["left", "right"],
1764+
xp: ModuleType,
1765+
) -> Array:
17561766
return xp.searchsorted(xp.asarray(a), xp.asarray(v), side=side)
17571767

17581768

@@ -1766,13 +1776,21 @@ class TestSearchsorted:
17661776
)
17671777
@pytest.mark.parametrize("nans_x", [False, True])
17681778
@pytest.mark.parametrize("infs_x", [False, True])
1769-
def test_nd(self, side, ties, shape, nans_x, infs_x, xp):
1779+
def test_nd(
1780+
self,
1781+
side: Literal["left", "right"],
1782+
ties: bool,
1783+
shape: int | tuple[int],
1784+
nans_x: bool,
1785+
infs_x: bool,
1786+
xp: ModuleType,
1787+
):
17701788
if nans_x and is_torch_namespace(xp):
17711789
pytest.skip("torch sorts NaNs differently")
17721790
rng = np.random.default_rng(945298725498274853)
17731791
x = rng.integers(5, size=shape) if ties else rng.random(shape)
17741792
# float32 is to accommodate JAX - nextafter with `float64` is too small?
1775-
x = np.asarray(x, dtype=np.float32)
1793+
x = np.asarray(x, dtype=np.float32) # type:ignore[assignment]
17761794
xr = np.nextafter(x, np.inf)
17771795
xl = np.nextafter(x, -np.inf)
17781796
x_ = np.asarray([-np.inf, np.inf, np.nan])
@@ -1786,7 +1804,7 @@ def test_nd(self, side, ties, shape, nans_x, infs_x, xp):
17861804
x[mask] = -np.inf
17871805
mask = rng.random(shape) > 0.9
17881806
x[mask] = np.inf
1789-
x = np.sort(x, stable=True, axis=-1)
1807+
x = np.sort(x, axis=-1) # type:ignore[assignment]
17901808
x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
17911809
xp_default_int = xp.asarray(1).dtype
17921810
if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0:

0 commit comments

Comments
 (0)