|
32 | 32 | searchsorted, |
33 | 33 | setdiff1d, |
34 | 34 | sinc, |
| 35 | + union1d, |
35 | 36 | ) |
36 | 37 | from array_api_extra._lib._backends import NUMPY_VERSION, Backend |
37 | 38 | from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal |
|
60 | 61 | lazy_xp_function(searchsorted) |
61 | 62 | lazy_xp_function(sinc) |
62 | 63 |
|
63 | | -NestedFloatList = list[float] | list["NestedFloatList"] |
64 | | - |
65 | 64 |
|
66 | 65 | class TestApplyWhere: |
67 | 66 | @staticmethod |
@@ -716,6 +715,7 @@ def test_0d_raises(self, xp: ModuleType): |
716 | 715 | (0, 1), |
717 | 716 | (1, 0), |
718 | 717 | (0, 0), |
| 718 | + (2, 3), |
719 | 719 | (4, 2, 1), |
720 | 720 | (1, 1, 7), |
721 | 721 | (0, 0, 1), |
@@ -1815,3 +1815,36 @@ def test_nd( |
1815 | 1815 | x, y = xp.asarray(x.copy()), xp.asarray(y.copy()) |
1816 | 1816 | res = searchsorted(x, y, side=side, xp=xp) |
1817 | 1817 | xp_assert_equal(res, ref) |
| 1818 | + |
| 1819 | + |
| 1820 | +@pytest.mark.skip_xp_backend( |
| 1821 | + Backend.ARRAY_API_STRICTEST, |
| 1822 | + reason="data_dependent_shapes flag for unique_values is disabled", |
| 1823 | +) |
| 1824 | +class TestUnion1d: |
| 1825 | + def test_simple(self, xp: ModuleType): |
| 1826 | + a = xp.asarray([-1, 1, 0]) |
| 1827 | + b = xp.asarray([2, -2, 0]) |
| 1828 | + expected = xp.asarray([-2, -1, 0, 1, 2]) |
| 1829 | + res = union1d(a, b) |
| 1830 | + xp_assert_equal(res, expected) |
| 1831 | + |
| 1832 | + def test_2d(self, xp: ModuleType): |
| 1833 | + a = xp.asarray([[-1, 1, 0], [1, 2, 0]]) |
| 1834 | + b = xp.asarray([[1, 0, 1], [-2, -1, 0]]) |
| 1835 | + expected = xp.asarray([-2, -1, 0, 1, 2]) |
| 1836 | + res = union1d(a, b) |
| 1837 | + xp_assert_equal(res, expected) |
| 1838 | + |
| 1839 | + def test_3d(self, xp: ModuleType): |
| 1840 | + a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]]) |
| 1841 | + b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]]) |
| 1842 | + expected = xp.asarray([-2, -1, 0, 1, 2]) |
| 1843 | + res = union1d(a, b) |
| 1844 | + xp_assert_equal(res, expected) |
| 1845 | + |
| 1846 | + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") |
| 1847 | + def test_device(self, xp: ModuleType, device: Device): |
| 1848 | + a = xp.asarray([-1, 1, 0], device=device) |
| 1849 | + b = xp.asarray([2, -2, 0], device=device) |
| 1850 | + assert get_device(union1d(a, b)) == device |
0 commit comments