diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py
index 14a3803b..bb165602 100644
--- a/src/array_api_extra/__init__.py
+++ b/src/array_api_extra/__init__.py
@@ -22,6 +22,7 @@
     default_dtype,
     kron,
     nunique,
+    searchsorted,
 )
 from ._lib._lazy import lazy_apply
@@ -48,6 +49,7 @@
     "one_hot",
     "pad",
     "partition",
+    "searchsorted",
     "setdiff1d",
     "sinc",
 ]
diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py
index 6e50ce95..3895014d 100644
--- a/src/array_api_extra/_lib/_funcs.py
+++ b/src/array_api_extra/_lib/_funcs.py
@@ -8,7 +8,12 @@
 from ._at import at
 from ._utils import _compat, _helpers
-from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
+from ._utils._compat import (
+    array_namespace,
+    is_dask_namespace,
+    is_jax_array,
+    is_torch_namespace,
+)
 from ._utils._helpers import (
     asarrays,
     capabilities,
@@ -28,6 +33,7 @@
     "kron",
     "nunique",
     "pad",
+    "searchsorted",
     "setdiff1d",
     "sinc",
 ]
@@ -665,6 +671,95 @@ def pad(
     return at(padded, tuple(slices)).set(x)
 
 
+def searchsorted(
+    x1: Array,
+    x2: Array,
+    /,
+    *,
+    side: Literal["left", "right"] = "left",
+    xp: ModuleType | None = None,
+) -> Array:
+    """
+    Find indices where elements should be inserted to maintain order.
+
+    Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
+    were inserted before the indices, the resulting array would remain sorted.
+
+    Parameters
+    ----------
+    x1 : Array
+        Input array. Should have a real-valued data type. Must be sorted in ascending
+        order along the last axis.
+    x2 : Array
+        Array containing search values. Should have a real-valued data type. Must have
+        the same shape as ``x1`` except along the last axis.
+    side : {'left', 'right'}, optional
+        Argument controlling which index is returned if an element of ``x2`` is equal to
+        one or more elements of ``x1``: ``'left'`` returns the index of the first of
+        these elements; ``'right'`` returns the next index after the last of these
+        elements. Default: ``'left'``.
+    xp : array_namespace, optional
+        The standard-compatible namespace for the array arguments. Default: infer.
+
+    Returns
+    -------
+    Array: integer array
+        An array of indices with the same shape as ``x2``.
+
+    Examples
+    --------
+    >>> import array_api_strict as xp
+    >>> import array_api_extra as xpx
+    >>> x = xp.asarray([11, 12, 13, 13, 14, 15])
+    >>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp)
+    Array([0, 1, 5, 6], dtype=array_api_strict.int64)
+    >>> xpx.searchsorted(x, xp.asarray(13), xp=xp)
+    Array(2, dtype=array_api_strict.int64)
+    >>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp)
+    Array(4, dtype=array_api_strict.int64)
+
+    `searchsorted` is vectorized along the last axis.
+
+    >>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
+    >>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]])
+    >>> xpx.searchsorted(x1, x2, xp=xp)
+    Array([[1, 3],
+           [2, 4]], dtype=array_api_strict.int64)
+    """
+    xp = array_namespace(x1, x2) if xp is None else xp
+    xp_default_int = xp.asarray(1).dtype
+    y_0d = xp.asarray(x2).ndim == 0
+    x_1d = x1.ndim <= 1
+
+    if x_1d or is_torch_namespace(xp):
+        x2 = xp.reshape(x2, ()) if (y_0d and x_1d) else x2
+        out = xp.searchsorted(x1, x2, side=side)
+        return xp.astype(out, xp_default_int, copy=False)
+
+    a = xp.full(x2.shape, 0, device=_compat.device(x1))
+
+    if x1.shape[-1] == 0:
+        return a
+
+    n = xp.count_nonzero(~xp.isnan(x1), axis=-1, keepdims=True)
+    b = xp.broadcast_to(n, x2.shape)
+
+    compare = xp.less_equal if side == "left" else xp.less
+
+    # while xp.any(b - a > 1):
+    # refactored to for loop with ~log2(n) iterations for JAX JIT
+    for _ in range(int(math.log2(x1.shape[-1])) + 1):  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+        c = (a + b) // 2
+        x0 = xp.take_along_axis(x1, c, axis=-1)
+        j = compare(x2, x0)
+        b = xp.where(j, c, b)
+        a = xp.where(j, a, c)
+
+    out = xp.where(compare(x2, xp.min(x1, axis=-1, keepdims=True)), 0, b)
+    out = xp.where(xp.isnan(x2), x1.shape[-1], out) if side == "right" else out
+    return xp.astype(out, xp_default_int, copy=False)
+
+
 def setdiff1d(
     x1: Array | complex,
     x2: Array | complex,
     /,
diff --git a/tests/test_funcs.py b/tests/test_funcs.py
index 6b10757f..f455541c 100644
--- a/tests/test_funcs.py
+++ b/tests/test_funcs.py
@@ -1,7 +1,7 @@
 import math
 import warnings
 from types import ModuleType
-from typing import Any, cast
+from typing import Any, Literal, cast
 
 import hypothesis
 import hypothesis.extra.numpy as npst
@@ -29,13 +29,18 @@
     one_hot,
     pad,
     partition,
+    searchsorted,
     setdiff1d,
     sinc,
 )
 from array_api_extra._lib._backends import NUMPY_VERSION, Backend
 from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
+from array_api_extra._lib._utils._compat import (
+    array_namespace,
+    is_jax_namespace,
+    is_torch_namespace,
+)
 from array_api_extra._lib._utils._compat import device as get_device
-from array_api_extra._lib._utils._compat import is_jax_namespace
 from array_api_extra._lib._utils._helpers import eager_shape, ndindex
 from array_api_extra._lib._utils._typing import Array, Device
 from array_api_extra.testing import lazy_xp_function
@@ -52,6 +57,7 @@
 lazy_xp_function(pad)
 # FIXME calls in1d which calls xp.unique_values without size
 lazy_xp_function(setdiff1d, jax_jit=False)
+lazy_xp_function(searchsorted)
 lazy_xp_function(sinc)
 
 NestedFloatList = list[float] | list["NestedFloatList"]
@@ -1637,3 +1643,175 @@ def test_kind(self, xp: ModuleType, library: Backend):
         expected = xp.asarray([False, True, False, True])
         res = isin(a, b, kind="sort")
         xp_assert_equal(res, expected)
+
+
+def _apply_over_batch(*argdefs: tuple[str, int]) -> Any:
+    """
+    Factory for a decorator that applies a function over batched arguments.
+
+    Copied (with light simplifications) from `scipy._lib._util`.
+
+    Array arguments may have any number of core dimensions (typically 0,
+    1, or 2) and any broadcastable batch shapes. There may be any
+    number of array outputs of any number of dimensions. Assumptions
+    right now - which are satisfied by all functions of interest in `linalg` -
+    are that all array inputs are consecutive keyword or positional arguments,
+    and that the wrapped function returns either a single array or a tuple of
+    arrays.
+    It's only as general as it needs to be right now - it can be extended.
+
+    Parameters
+    ----------
+    *argdefs : tuple of (str, int)
+        Definitions of array arguments: the keyword name of the argument, and
+        the number of core dimensions.
+
+    Examples
+    --------
+    `linalg.eig` accepts two matrices as the first two arguments `a` and `b`, where
+    `b` is optional, and returns one array or a tuple of arrays, depending on the
+    values of other positional or keyword arguments. To generate a wrapper that applies
+    the function over batches of `a` and optionally `b`:
+
+    >>> _apply_over_batch(('a', 2), ('b', 2))
+    """
+    names, ndims = list(zip(*argdefs, strict=True))
+    n_arrays = len(names)
+
+    def decorator(f: Any) -> Any:
+        def wrapper(
+            *args_tuple: tuple[Any] | None,
+            **kwargs: dict[str, Any] | None,
+        ) -> Any:
+            args = list(args_tuple)
+
+            # Ensure all arrays in `arrays`, other arguments in `other_args`/`kwargs`
+            arrays, other_args = args[:n_arrays], args[n_arrays:]
+            for i, name in enumerate(names):
+                if name in kwargs:
+                    if i + 1 <= len(args):
+                        message = (
+                            f"{f.__name__}() got multiple values for argument `{name}`."
+                        )
+                        raise ValueError(message)
+                    arrays.append(kwargs.pop(name))  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+
+            xp = array_namespace(*arrays)  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
+
+            # Determine core and batch shapes
+            batch_shapes = []
+            core_shapes = []
+            for i, (array, ndim) in enumerate(zip(arrays, ndims, strict=True)):
+                array = None if array is None else xp.asarray(array)  # noqa: PLW2901
+                shape = () if array is None else array.shape
+                arrays[i] = array
+                batch_shapes.append(shape[:-ndim] if ndim > 0 else shape)
+                core_shapes.append(shape[-ndim:] if ndim > 0 else ())
+
+            # Early exit if call is not batched
+            if not any(batch_shapes):
+                return f(*arrays, *other_args, **kwargs)
+
+            # Determine broadcasted batch shape
+            batch_shape = np.broadcast_shapes(*batch_shapes)  # Gives OK error message
+
+            # Broadcast arrays to appropriate shape
+            for i, (array, core_shape) in enumerate(
+                zip(arrays, core_shapes, strict=True)
+            ):
+                if array is None:
+                    continue
+                arrays[i] = xp.broadcast_to(array, batch_shape + core_shape)
+
+            # Main loop
+            results = []
+            for index in np.ndindex(batch_shape):
+                result = f(
+                    *(
+                        (array[index] if array is not None else None)
+                        for array in arrays
+                    ),
+                    *other_args,
+                    **kwargs,
+                )
+                # Assume `result` is either a tuple or single array. This is easily
+                # generalized by allowing the contributor to pass an `unpack_result`
+                # callable to the decorator factory.
+                result = (result,) if not isinstance(result, tuple) else result
+                results.append(result)
+            results = list(zip(*results, strict=True))
+
+            # Reshape results
+            for i, result in enumerate(results):
+                result = xp.stack(result)  # noqa: PLW2901
+                core_shape = result.shape[1:]
+                results[i] = xp.reshape(result, batch_shape + core_shape)
+
+            # Assume `result` should be a single array if there is only one element or
+            # a `tuple` otherwise. This is easily generalized by allowing the
+            # contributor to pass a `pack_result` callable to the decorator factory.
+            return results[0] if len(results) == 1 else results
+
+        return wrapper
+
+    return decorator
+
+
+@_apply_over_batch(("a", 1), ("v", 1))  # type: ignore[misc]
+def xp_searchsorted(
+    a: Array,
+    v: Array,
+    side: Literal["left", "right"],
+    xp: ModuleType,
+) -> Array:
+    return xp.searchsorted(xp.asarray(a), xp.asarray(v), side=side)
+
+
+@pytest.mark.skip_xp_backend(Backend.DASK, reason="no take_along_axis")
+@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no searchsorted")
+class TestSearchsorted:
+    @pytest.mark.parametrize("side", ["left", "right"])
+    @pytest.mark.parametrize("ties", [False, True])
+    @pytest.mark.parametrize(
+        "shape", [0, 1, 2, 10, 11, 1000, 10001, (2, 0), (0, 2), (2, 10), (2, 3, 11)]
+    )
+    @pytest.mark.parametrize("nans_x", [False, True])
+    @pytest.mark.parametrize("infs_x", [False, True])
+    def test_nd(
+        self,
+        side: Literal["left", "right"],
+        ties: bool,
+        shape: int | tuple[int, ...],
+        nans_x: bool,
+        infs_x: bool,
+        xp: ModuleType,
+    ):
+        if nans_x and is_torch_namespace(xp):
+            pytest.skip("torch sorts NaNs differently")
+        rng = np.random.default_rng(945298725498274853)
+        x = rng.integers(5, size=shape) if ties else rng.random(shape)
+        # float32 is to accommodate JAX - nextafter with `float64` is too small?
+        x = np.asarray(x, dtype=np.float32)  # type:ignore[assignment]
+        xr = np.nextafter(x, np.inf)
+        xl = np.nextafter(x, -np.inf)
+        x_ = np.asarray([-np.inf, np.inf, np.nan])
+        x_ = np.broadcast_to(x_, (*x.shape[:-1], 3))
+        y = rng.permuted(np.concatenate((xl, x, xr, x_), axis=-1), axis=-1)
+        if nans_x:
+            mask = rng.random(shape) < 0.1
+            x[mask] = np.nan
+        if infs_x:
+            mask = rng.random(shape) < 0.1
+            x[mask] = -np.inf
+            mask = rng.random(shape) > 0.9
+            x[mask] = np.inf
+        x = np.sort(x, axis=-1)  # type:ignore[assignment]
+        x, y = np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)
+        xp_default_int = xp.asarray(1).dtype
+        if x.size == 0 and x.ndim > 0 and x.shape[-1] != 0:
+            ref = xp.empty((*x.shape[:-1], y.shape[-1]), dtype=xp_default_int)
+        else:
+            ref = xp_searchsorted(x, y, side=side, xp=np)
+            ref = xp.asarray(ref, dtype=xp_default_int)
+        x, y = xp.asarray(x.copy()), xp.asarray(y.copy())
+        res = searchsorted(x, y, side=side, xp=xp)
+        xp_assert_equal(res, ref)
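
Note (not part of the patch): the generic path of `searchsorted` above replaces a
data-dependent `while` loop with a fixed number of bisection steps,
`int(math.log2(n)) + 1`, so that the function remains traceable under JAX JIT.
The sketch below is a minimal NumPy-only illustration of that idea for
`side='left'`; `bisect_left_sketch` is a hypothetical name, and it omits the NaN
handling, the empty-array early exit, and the 1-D/torch fast path that defers to
`xp.searchsorted` in the real implementation.

import math

import numpy as np


def bisect_left_sketch(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
    """Vectorized left-bisection along the last axis with a static loop bound."""
    # Candidate insertion indices are bracketed by [a, b), one bracket per x2 element.
    a = np.zeros(x2.shape, dtype=np.int64)
    b = np.full(x2.shape, x1.shape[-1], dtype=np.int64)
    # log2(n) + 1 halvings are enough to shrink every bracket to width <= 1.
    for _ in range(int(math.log2(x1.shape[-1])) + 1):
        c = (a + b) // 2
        j = x2 <= np.take_along_axis(x1, c, axis=-1)  # 'left': ties move the upper bound
        b = np.where(j, c, b)  # x2 <= x1[..., c]: insertion point is at or before c
        a = np.where(j, a, c)  # x2 >  x1[..., c]: insertion point is after c
    return b


x1 = np.asarray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
x2 = np.asarray([[1.1, 3.3], [6.6, 8.8]])
print(bisect_left_sketch(x1, x2))  # same indices as the docstring example: [[1 3], [2 4]]

The real function additionally initializes the upper bound from the count of non-NaN
entries, adds a final clamp against `min(x1)`, and for `side='right'` switches the
comparison to strict less and sends NaN search values to the end.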