11import math
22import warnings
33from types import ModuleType
4- from typing import Any , cast
4+ from typing import Any , Literal , cast
55
66import hypothesis
77import hypothesis .extra .numpy as npst
@@ -1645,10 +1645,12 @@ def test_kind(self, xp: ModuleType, library: Backend):
16451645 xp_assert_equal (res , expected )
16461646
16471647
1648- def _apply_over_batch (* argdefs : tuple [str , int ]):
1648+ def _apply_over_batch (* argdefs : tuple [str , int ]) -> Any :
16491649 """
16501650 Factory for decorator that applies a function over batched arguments.
16511651
1652+ Copied (with light simplifications) from `scipy._lib._util`.
1653+
16521654 Array arguments may have any number of core dimensions (typically 0,
16531655 1, or 2) and any broadcastable batch shapes. There may be any
16541656 number of array outputs of any number of dimensions. Assumptions
@@ -1675,8 +1677,11 @@ def _apply_over_batch(*argdefs: tuple[str, int]):
16751677 names , ndims = list (zip (* argdefs , strict = True ))
16761678 n_arrays = len (names )
16771679
1678- def decorator (f ):
1679- def wrapper (* args_tuple , ** kwargs ):
1680+ def decorator (f : Any ) -> Any :
1681+ def wrapper (
1682+ * args_tuple : tuple [Any ] | None ,
1683+ ** kwargs : dict [str , Any ] | None ,
1684+ ) -> Any :
16801685 args = list (args_tuple )
16811686
16821687 # Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs`
@@ -1688,9 +1693,9 @@ def wrapper(*args_tuple, **kwargs):
16881693 f"{ f .__name__ } () got multiple values for argument `{ name } `."
16891694 )
16901695 raise ValueError (message )
1691- arrays .append (kwargs .pop (name ))
1696+ arrays .append (kwargs .pop (name )) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
16921697
1693- xp = array_namespace (* arrays )
1698+ xp = array_namespace (* arrays ) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
16941699
16951700 # Determine core and batch shapes
16961701 batch_shapes = []
@@ -1751,8 +1756,13 @@ def wrapper(*args_tuple, **kwargs):
17511756 return decorator
17521757
17531758
1754- @_apply_over_batch (("a" , 1 ), ("v" , 1 ))
1755- def xp_searchsorted (a , v , side , xp ):
1759+ @_apply_over_batch (("a" , 1 ), ("v" , 1 )) # type: ignore[misc]
1760+ def xp_searchsorted (
1761+ a : Array ,
1762+ v : Array ,
1763+ side : Literal ["left" , "right" ],
1764+ xp : ModuleType ,
1765+ ) -> Array :
17561766 return xp .searchsorted (xp .asarray (a ), xp .asarray (v ), side = side )
17571767
17581768
@@ -1766,13 +1776,21 @@ class TestSearchsorted:
17661776 )
17671777 @pytest .mark .parametrize ("nans_x" , [False , True ])
17681778 @pytest .mark .parametrize ("infs_x" , [False , True ])
1769- def test_nd (self , side , ties , shape , nans_x , infs_x , xp ):
1779+ def test_nd (
1780+ self ,
1781+ side : Literal ["left" , "right" ],
1782+ ties : bool ,
1783+ shape : int | tuple [int ],
1784+ nans_x : bool ,
1785+ infs_x : bool ,
1786+ xp : ModuleType ,
1787+ ):
17701788 if nans_x and is_torch_namespace (xp ):
17711789 pytest .skip ("torch sorts NaNs differently" )
17721790 rng = np .random .default_rng (945298725498274853 )
17731791 x = rng .integers (5 , size = shape ) if ties else rng .random (shape )
17741792 # float32 is to accommodate JAX - nextafter with `float64` is too small?
1775- x = np .asarray (x , dtype = np .float32 )
1793+ x = np .asarray (x , dtype = np .float32 ) # type:ignore[assignment]
17761794 xr = np .nextafter (x , np .inf )
17771795 xl = np .nextafter (x , - np .inf )
17781796 x_ = np .asarray ([- np .inf , np .inf , np .nan ])
@@ -1786,7 +1804,7 @@ def test_nd(self, side, ties, shape, nans_x, infs_x, xp):
17861804 x [mask ] = - np .inf
17871805 mask = rng .random (shape ) > 0.9
17881806 x [mask ] = np .inf
1789- x = np .sort (x , stable = True , axis = - 1 )
1807+ x = np .sort (x , axis = - 1 ) # type:ignore[assignment]
17901808 x , y = np .asarray (x , dtype = np .float64 ), np .asarray (y , dtype = np .float64 )
17911809 xp_default_int = xp .asarray (1 ).dtype
17921810 if x .size == 0 and x .ndim > 0 and x .shape [- 1 ] != 0 :
0 commit comments