Closed
Labels: API extension (adds new functions or objects to the API)
Description
This RFC proposes adding support to the array API specification for finding the indices at which elements should be inserted into a sorted array in order to maintain its sort order.
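As a small illustration of the core idea, here is a pure-Python sketch that uses the standard library's `bisect` module rather than any array library. The helper name `insertion_indices` is hypothetical. For each value, it finds the index at which that value can be inserted into a sorted list while keeping the list sorted.

```python
import bisect


def insertion_indices(sorted_values, values):
    """Hypothetical helper: index at which each value keeps sorted_values ordered."""
    # bisect_left returns the leftmost position that preserves ascending order.
    return [bisect.bisect_left(sorted_values, v) for v in values]


x1 = [1, 3, 5, 7]
x2 = [0, 4, 7, 9]
idx = insertion_indices(x1, x2)
print(idx)  # [0, 2, 3, 4]

# Inserting any single value at its computed index leaves the list sorted.
for v, i in zip(x2, idx):
    merged = x1[:i] + [v] + x1[i:]
    assert merged == sorted(merged)
```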
Overview
Based on array comparison data, the API is available across all considered libraries.
Furthermore, all considered libraries support the `side` keyword argument, and all except TensorFlow support the `sorter` keyword argument.
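To make the two keyword arguments concrete, the sketch below models `side` and `sorter` in pure Python with `bisect`. The function name `search` is hypothetical, and this is not any library's implementation. With duplicates in `x1`, `side` only changes the result for values equal to an existing element. With `sorter`, `x1` itself need not be sorted: the search runs over `x1` taken in `sorter` order, so the returned indices refer to positions in that sorted view (as in NumPy).

```python
import bisect


def search(x1, x2, side="left", sorter=None):
    """Hypothetical pure-Python model of the side/sorter keyword arguments."""
    if side not in ("left", "right"):
        raise ValueError('side must be "left" or "right"')
    # With a sorter, search the sorted view x1[sorter] instead of x1 itself.
    view = x1 if sorter is None else [x1[k] for k in sorter]
    find = bisect.bisect_left if side == "left" else bisect.bisect_right
    return [find(view, v) for v in x2]


x1 = [1, 2, 2, 2, 3]
print(search(x1, [2], side="left"))   # [1]
print(search(x1, [2], side="right"))  # [4]

# Unsorted x1 plus an argsort-style permutation built without any array library.
unsorted = [30, 10, 20]
order = sorted(range(len(unsorted)), key=unsorted.__getitem__)  # [1, 2, 0]
print(search(unsorted, [15, 25], sorter=order))  # [1, 2]
```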
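JAX supports an additional keyword argument, `method`, which selects the search implementation based on performance considerations such as device and input size. PyTorch and TensorFlow support specifying the output data type, but with different conventions: PyTorch uses the boolean `out_int32` flag, while TensorFlow uses `out_type`.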
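The point of a `method`-style argument is that different algorithms can return identical indices with very different cost profiles. The pure-Python sketch below contrasts a binary search with a brute-force "compare against every element" count. Both function names are hypothetical, and this only illustrates the trade-off; it is not how JAX implements any of its strategies.

```python
import bisect


def search_binary(x1, x2, side="left"):
    """Binary search: O(M log N) comparisons for M queries into N sorted elements."""
    find = bisect.bisect_left if side == "left" else bisect.bisect_right
    return [find(x1, v) for v in x2]


def search_compare_all(x1, x2, side="left"):
    """Compare every query against every element: O(M * N) work, but it uses only
    elementwise comparisons and a sum, which maps well to vectorized hardware."""
    if side == "left":
        # Index = number of elements strictly less than the value.
        return [sum(1 for a in x1 if a < v) for v in x2]
    # Index = number of elements less than or equal to the value.
    return [sum(1 for a in x1 if a <= v) for v in x2]


x1 = [1, 2, 2, 4, 8]
x2 = [0, 2, 3, 8, 9]
for side in ("left", "right"):
    assert search_binary(x1, x2, side) == search_compare_all(x1, x2, side)
print(search_compare_all(x1, x2, "right"))  # [0, 3, 3, 5, 5]
```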
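All considered libraries support one-dimensional arrays. PyTorch and TensorFlow additionally generalize to n-dimensional arrays (stacking), performing an independent search along the innermost dimension for each combination of leading indices.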
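A minimal sketch of stacked-search semantics, in the spirit of the PyTorch and TensorFlow behavior described above, using nested Python lists: each row of values is searched in the corresponding row of the sorted sequence. The helper name is hypothetical, and only the 2-D case is shown; higher dimensions repeat the same idea over all leading dimensions.

```python
import bisect


def searchsorted_stacked(sorted_rows, value_rows, side="left"):
    """Hypothetical 2-D model: search each row of values in the matching sorted row."""
    if len(sorted_rows) != len(value_rows):
        raise ValueError("leading dimensions must match")
    find = bisect.bisect_left if side == "left" else bisect.bisect_right
    return [[find(row, v) for v in vals] for row, vals in zip(sorted_rows, value_rows)]


x1 = [[1, 3, 5], [2, 4, 6]]
x2 = [[3, 6], [1, 5]]
print(searchsorted_stacked(x1, x2))  # [[1, 3], [0, 2]]
```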
Prior Art
- NumPy: https://numpy.org/doc/stable/reference/generated/numpy.searchsorted.html
- PyTorch: https://pytorch.org/docs/stable/generated/torch.searchsorted.html
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/searchsorted
- JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.searchsorted.html
- CuPy: https://docs.cupy.dev/en/stable/reference/generated/cupy.searchsorted.html
- Dask: https://docs.dask.org/en/stable/generated/dask.array.searchsorted.html
Proposal
`def searchsorted(x1: array, x2: array, /, *, side="left", sorter=None)`

- `x1`: one-dimensional array. If `sorter` is `None`, `x1` must be sorted in ascending order.
- `x2`: one-dimensional array.
- `side`: if `"left"`, the returned index `i` satisfies `x1[i-1] < x2[j] <= x1[i]`. Otherwise, if `"right"`, the returned index `i` satisfies `x1[i-1] <= x2[j] < x1[i]`. If no suitable index exists, `i` is `0` when `x2[j]` belongs before every element of `x1`, or `N` when it belongs after every element, where `N` is the length of `x1`.
- `sorter`: array of integer indices that sort `x1` in ascending order (e.g., as might be produced via `argsort`).
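The following pure-Python sketch follows the proposed signature on plain lists rather than array objects (an assumption made to keep it self-contained); the names `searchsorted_ref` and `satisfies` are hypothetical. It performs an explicit binary search and then checks the stated invariants for both values of `side`.

```python
def searchsorted_ref(x1, x2, side="left", sorter=None):
    """Hypothetical list-based model of the proposed searchsorted semantics."""
    if side not in ("left", "right"):
        raise ValueError('side must be "left" or "right"')
    # With a sorter, search x1 taken in sorter order.
    view = x1 if sorter is None else [x1[k] for k in sorter]
    out = []
    for v in x2:
        lo, hi = 0, len(view)
        # Invariant: the answer lies in the closed range [lo, hi].
        while lo < hi:
            mid = (lo + hi) // 2
            # "left" moves past elements < v; "right" also moves past elements == v.
            if view[mid] < v or (side == "right" and view[mid] == v):
                lo = mid + 1
            else:
                hi = mid
        out.append(lo)
    return out


def satisfies(view, v, i, side):
    """Check x1[i-1] < v <= x1[i] (left) or x1[i-1] <= v < x1[i] (right),
    treating out-of-range neighbours as absent."""
    below_ok = i == 0 or (view[i - 1] < v if side == "left" else view[i - 1] <= v)
    above_ok = i == len(view) or (v <= view[i] if side == "left" else v < view[i])
    return below_ok and above_ok


x1 = [1, 2, 2, 3]
x2 = [0, 2, 2.5, 4]
for side in ("left", "right"):
    idx = searchsorted_ref(x1, x2, side=side)
    assert all(satisfies(x1, v, i, side) for v, i in zip(x2, idx))
print(searchsorted_ref(x1, x2, side="left"))   # [0, 1, 3, 4]
print(searchsorted_ref(x1, x2, side="right"))  # [0, 3, 3, 4]
```

With a `sorter`, the result indexes the sorted view: `searchsorted_ref([30, 10, 20], [25], sorter=[1, 2, 0])` gives `[2]`.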
Questions
- Should the API be extended to support stacking as in PyTorch/TensorFlow?
- Should the API support a scalar value for `x2`? NumPy, PyTorch, JAX, and Dask support scalars; CuPy and TensorFlow do not.
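If scalar `x2` were allowed, one possible convention (NumPy, for instance, returns a single integer index when the value is a scalar) is to treat the scalar as a one-element query and return a single index. The sketch below is a hypothetical wrapper over plain lists; its name is illustrative only.

```python
import bisect
from numbers import Real


def searchsorted_maybe_scalar(x1, x2, side="left"):
    """Hypothetical wrapper: accept a scalar or a list for x2 and mirror its shape."""
    find = bisect.bisect_left if side == "left" else bisect.bisect_right
    if isinstance(x2, Real):
        # Scalar query: return a single index, analogous to a 0-d result.
        return find(x1, x2)
    return [find(x1, v) for v in x2]


x1 = [10, 20, 30]
print(searchsorted_maybe_scalar(x1, 25))       # 2
print(searchsorted_maybe_scalar(x1, [5, 25]))  # [0, 2]
```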
Metadata
Status: Stage 2