Skip to content

Commit 699fc51

Browse files
authored
Merge branch 'main' into searchsorted
2 parents a094ffe + 2eb0d8e commit 699fc51

File tree

7 files changed

+202
-115
lines changed

7 files changed

+202
-115
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@
2525
partition
2626
setdiff1d
2727
sinc
28+
union1d
2829
```

pixi.lock

Lines changed: 102 additions & 95 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,22 @@ array-api-extra = { path = ".", editable = true }
5555

5656
[tool.pixi.feature.lint.dependencies]
5757
typing-extensions = ">=4.15.0"
58-
pylint = ">=4.0.2"
58+
pylint = ">=4.0.3"
5959
mypy = ">=1.18.2"
60-
basedpyright = ">=1.33.0"
60+
basedpyright = ">=1.34.0"
6161
numpydoc = ">=1.9.0,<2"
6262
# import dependencies for mypy:
6363
array-api-strict = ">=2.4.1,<2.5"
6464
numpy = ">=2.1.3"
65-
hypothesis = ">=6.142.4"
65+
hypothesis = ">=6.148.2"
6666
dask-core = ">=2025.11.0" # No distributed, tornado, etc.
6767
dprint = ">=0.50.0,<0.51"
68-
lefthook = ">=2.0.2,<3"
69-
ruff = ">=0.14.4,<0.15"
70-
typos = ">=1.39.0,<2"
71-
actionlint = ">=1.7.8,<2"
68+
lefthook = ">=2.0.4,<3"
69+
ruff = ">=0.14.6,<0.15"
70+
typos = ">=1.40.0,<2"
71+
actionlint = ">=1.7.9,<2"
7272
blacken-docs = ">=1.20.0,<2"
73-
pytest = ">=8.4.2,<9"
73+
pytest = ">=9.0.1,<10"
7474
validate-pyproject = ">=0.24.1,<0.25"
7575
# NOTE: don't add cupy, jax, pytorch, or sparse here,
7676
# as they slow down mypy and are not portable across target OSs
@@ -93,9 +93,9 @@ numpydoc = { cmd = "numpydoc lint", description = "Validate docstrings with nump
9393
lint = { cmd = "lefthook run pre-commit --all-files --force", description = "Run all linters" }
9494

9595
[tool.pixi.feature.tests.dependencies]
96-
pytest = ">=8.4.2"
96+
pytest = ">=9.0.1"
9797
pytest-cov = ">=7.0.0"
98-
hypothesis = ">=6.142.4"
98+
hypothesis = ">=6.148.2"
9999
array-api-strict = ">=2.4.1,<2.5"
100100
numpy = ">=1.22.0"
101101

@@ -121,7 +121,7 @@ sphinx-copybutton = ">=0.5.2"
121121
sphinx-autodoc-typehints = ">=1.25.3"
122122
# Needed to import parsed modules with autodoc
123123
dask-core = ">=2025.11.0" # No distributed, tornado, etc.
124-
pytest = ">=8.4.2"
124+
pytest = ">=9.0.1"
125125
typing-extensions = ">=4.15.0"
126126
numpy = ">=2.1.3"
127127

@@ -200,8 +200,8 @@ pytorch = { version = ">=2.8.0", build = "cuda12*" }
200200

201201
[tool.pixi.feature.nogil.dependencies]
202202
python-freethreading = "~=3.13.0"
203-
pytest-run-parallel = ">=0.7.1"
204-
numpy = ">=2.3.4"
203+
pytest-run-parallel = ">=0.8.0"
204+
numpy = ">=2.3.5"
205205
# pytorch = "*" # Not available on Python 3.13t yet
206206
dask-core = ">=2025.11.0" # No distributed, tornado, etc.
207207
# sparse = "*" # numba not available on Python 3.13t yet

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
partition,
1515
setdiff1d,
1616
sinc,
17+
union1d,
1718
)
1819
from ._lib._at import at
1920
from ._lib._funcs import (
@@ -52,4 +53,5 @@
5253
"searchsorted",
5354
"setdiff1d",
5455
"sinc",
56+
"union1d",
5557
]

src/array_api_extra/_delegation.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,14 @@ def create_diagonal(
226226
if is_torch_namespace(xp):
227227
return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1)
228228

229-
if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2:
229+
if (
230+
is_dask_namespace(xp)
231+
or is_cupy_namespace(xp)
232+
or is_numpy_namespace(xp)
233+
or is_jax_namespace(xp)
234+
) and (x.ndim < 2):
230235
return xp.diag(x, k=offset)
231236

232-
if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3:
233-
batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset)
234-
return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n))
235-
236237
return _funcs.create_diagonal(x, offset=offset, xp=xp)
237238

238239

@@ -1026,3 +1027,37 @@ def isin(
10261027
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)
10271028

10281029
return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)
1030+
1031+
1032+
def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
1033+
"""
1034+
Find the union of two arrays.
1035+
1036+
Return the unique, sorted array of values that are in either of the two
1037+
input arrays.
1038+
1039+
Parameters
1040+
----------
1041+
a, b : Array
1042+
Input arrays. They are flattened internally if they are not already 1D.
1043+
1044+
xp : array_namespace, optional
1045+
The standard-compatible namespace for `a` and `b`. Default: infer.
1046+
1047+
Returns
1048+
-------
1049+
Array
1050+
Unique, sorted union of the input arrays.
1051+
"""
1052+
if xp is None:
1053+
xp = array_namespace(a, b)
1054+
1055+
if (
1056+
is_numpy_namespace(xp)
1057+
or is_cupy_namespace(xp)
1058+
or is_dask_namespace(xp)
1059+
or is_jax_namespace(xp)
1060+
):
1061+
return xp.union1d(a, b)
1062+
1063+
return _funcs.union1d(a, b, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,3 +837,12 @@ def isin( # numpydoc ignore=PR01,RT01
837837
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
838838
original_a_shape,
839839
)
840+
841+
842+
def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
843+
# numpydoc ignore=PR01,RT01
844+
"""See docstring in `array_api_extra._delegation.py`."""
845+
a = xp.reshape(a, (-1,))
846+
b = xp.reshape(b, (-1,))
847+
# XXX: `sparse` returns NumPy arrays from `unique_values`
848+
return xp.asarray(xp.unique_values(xp.concat([a, b])))

tests/test_funcs.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
searchsorted,
3333
setdiff1d,
3434
sinc,
35+
union1d,
3536
)
3637
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
3738
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
@@ -60,8 +61,6 @@
6061
lazy_xp_function(searchsorted)
6162
lazy_xp_function(sinc)
6263

63-
NestedFloatList = list[float] | list["NestedFloatList"]
64-
6564

6665
class TestApplyWhere:
6766
@staticmethod
@@ -716,6 +715,7 @@ def test_0d_raises(self, xp: ModuleType):
716715
(0, 1),
717716
(1, 0),
718717
(0, 0),
718+
(2, 3),
719719
(4, 2, 1),
720720
(1, 1, 7),
721721
(0, 0, 1),
@@ -1815,3 +1815,36 @@ def test_nd(
18151815
x, y = xp.asarray(x.copy()), xp.asarray(y.copy())
18161816
res = searchsorted(x, y, side=side, xp=xp)
18171817
xp_assert_equal(res, ref)
1818+
1819+
1820+
@pytest.mark.skip_xp_backend(
1821+
Backend.ARRAY_API_STRICTEST,
1822+
reason="data_dependent_shapes flag for unique_values is disabled",
1823+
)
1824+
class TestUnion1d:
1825+
def test_simple(self, xp: ModuleType):
1826+
a = xp.asarray([-1, 1, 0])
1827+
b = xp.asarray([2, -2, 0])
1828+
expected = xp.asarray([-2, -1, 0, 1, 2])
1829+
res = union1d(a, b)
1830+
xp_assert_equal(res, expected)
1831+
1832+
def test_2d(self, xp: ModuleType):
1833+
a = xp.asarray([[-1, 1, 0], [1, 2, 0]])
1834+
b = xp.asarray([[1, 0, 1], [-2, -1, 0]])
1835+
expected = xp.asarray([-2, -1, 0, 1, 2])
1836+
res = union1d(a, b)
1837+
xp_assert_equal(res, expected)
1838+
1839+
def test_3d(self, xp: ModuleType):
1840+
a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]])
1841+
b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]])
1842+
expected = xp.asarray([-2, -1, 0, 1, 2])
1843+
res = union1d(a, b)
1844+
xp_assert_equal(res, expected)
1845+
1846+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
1847+
def test_device(self, xp: ModuleType, device: Device):
1848+
a = xp.asarray([-1, 1, 0], device=device)
1849+
b = xp.asarray([2, -2, 0], device=device)
1850+
assert get_device(union1d(a, b)) == device

0 commit comments

Comments
 (0)