Skip to content

Commit 182a6d1

Browse files
committed
address review comments
1 parent 7993474 commit 182a6d1

File tree

2 files changed

+108
-86
lines changed

2 files changed

+108
-86
lines changed

dpnp/dpnp_iface_logic.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,8 @@ def isin(
11711171
test_elements,
11721172
assume_unique=False, # pylint: disable=unused-argument
11731173
invert=False,
1174+
*,
1175+
kind=None, # pylint: disable=unused-argument
11741176
):
11751177
"""
11761178
Calculates ``element in test_elements``, broadcasting over `element` only.
@@ -1190,22 +1192,27 @@ def isin(
11901192
assume_unique : bool, optional
11911193
Ignored, as no performance benefit is gained by assuming the
11921194
input arrays are unique. Included for compatibility with NumPy.
1195+
11931196
Default: ``False``.
11941197
invert : bool, optional
11951198
If ``True``, the values in the returned array are inverted, as if
1196-
calculating `element not in test_elements`.
1199+
calculating ``element not in test_elements``.
11971200
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
11981201
than) ``dpnp.invert(dpnp.isin(a, b))``.
1202+
11991203
Default: ``False``.
1204+
kind : {None, "sort"}, optional
1205+
Ignored, as the only algorithm implemented is ``"sort"``. Included for
1206+
compatibility with NumPy.
12001207
1208+
Default: ``None``.
12011209
12021210
Returns
12031211
-------
12041212
isin : dpnp.ndarray of bool dtype
12051213
Has the same shape as `element`. The values `element[isin]`
12061214
are in `test_elements`.
12071215
1208-
12091216
Examples
12101217
--------
12111218
>>> import dpnp as np
@@ -1238,14 +1245,32 @@ def isin(
12381245
"""
12391246

12401247
dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
1241-
usm_element = dpnp.as_usm_ndarray(
1242-
element, usm_type=element.usm_type, sycl_queue=element.sycl_queue
1243-
)
1244-
usm_test = dpnp.as_usm_ndarray(
1245-
test_elements,
1246-
usm_type=test_elements.usm_type,
1247-
sycl_queue=test_elements.sycl_queue,
1248-
)
1248+
if dpnp.isscalar(element):
1249+
usm_element = dpnp.as_usm_ndarray(
1250+
element,
1251+
usm_type=test_elements.usm_type,
1252+
sycl_queue=test_elements.sycl_queue,
1253+
)
1254+
usm_test = dpnp.get_usm_ndarray(test_elements)
1255+
elif dpnp.isscalar(test_elements):
1256+
usm_test = dpnp.as_usm_ndarray(
1257+
test_elements,
1258+
usm_type=element.usm_type,
1259+
sycl_queue=element.sycl_queue,
1260+
)
1261+
usm_element = dpnp.get_usm_ndarray(element)
1262+
else:
1263+
if (
1264+
dpu.get_execution_queue(
1265+
(element.sycl_queue, test_elements.sycl_queue)
1266+
)
1267+
is None
1268+
):
1269+
raise dpu.ExecutionPlacementError(
1270+
"Input arrays have incompatible allocation queues"
1271+
)
1272+
usm_element = dpnp.get_usm_ndarray(element)
1273+
usm_test = dpnp.get_usm_ndarray(test_elements)
12491274
return dpnp.get_result_array(
12501275
dpt.isin(
12511276
usm_element,

dpnp/tests/test_logic.py

Lines changed: 73 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -797,98 +797,95 @@ def test_array_equal_nan(a):
797797
assert_equal(result, expected)
798798

799799

800-
@pytest.mark.parametrize(
801-
"a",
802-
[
803-
numpy.array([1, 2, 3, 4]),
804-
numpy.array([[1, 2], [3, 4]]),
805-
],
806-
)
807-
@pytest.mark.parametrize(
808-
"b",
809-
[
810-
numpy.array([2, 4, 6]),
811-
numpy.array([[1, 3], [5, 7]]),
812-
],
813-
)
814-
def test_isin_basic(a, b):
815-
dp_a = dpnp.array(a)
816-
dp_b = dpnp.array(b)
817-
818-
expected = numpy.isin(a, b)
819-
result = dpnp.isin(dp_a, dp_b)
820-
assert_equal(result, expected)
821-
822-
823-
@pytest.mark.parametrize("dtype", get_all_dtypes())
824-
def test_isin_dtype(dtype):
825-
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826-
b = numpy.array([2, 4], dtype=dtype)
827-
828-
dp_a = dpnp.array(a, dtype=dtype)
829-
dp_b = dpnp.array(b, dtype=dtype)
830-
831-
expected = numpy.isin(a, b)
832-
result = dpnp.isin(dp_a, dp_b)
833-
assert_equal(result, expected)
834-
800+
class TestIsin:
801+
@pytest.mark.parametrize(
802+
"a",
803+
[
804+
numpy.array([1, 2, 3, 4]),
805+
numpy.array([[1, 2], [3, 4]]),
806+
],
807+
)
808+
@pytest.mark.parametrize(
809+
"b",
810+
[
811+
numpy.array([2, 4, 6]),
812+
numpy.array([[1, 3], [5, 7]]),
813+
],
814+
)
815+
def test_isin_basic(self, a, b):
816+
dp_a = dpnp.array(a)
817+
dp_b = dpnp.array(b)
835818

836-
@pytest.mark.parametrize("sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))])
837-
def test_isin_broadcast(sh_a, sh_b):
838-
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
839-
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
819+
expected = numpy.isin(a, b)
820+
result = dpnp.isin(dp_a, dp_b)
821+
assert_equal(result, expected)
840822

841-
dp_a = dpnp.array(a)
842-
dp_b = dpnp.array(b)
823+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
824+
def test_isin_dtype(self, dtype):
825+
a = numpy.array([1, 2, 3, 4], dtype=dtype)
826+
b = numpy.array([2, 4], dtype=dtype)
843827

844-
expected = numpy.isin(a, b)
845-
result = dpnp.isin(dp_a, dp_b)
846-
assert_equal(result, expected)
828+
dp_a = dpnp.array(a, dtype=dtype)
829+
dp_b = dpnp.array(b, dtype=dtype)
847830

831+
expected = numpy.isin(a, b)
832+
result = dpnp.isin(dp_a, dp_b)
833+
assert_equal(result, expected)
848834

849-
def test_isin_scalar_elements():
850-
a = numpy.array([1, 2, 3])
851-
b = 2
835+
@pytest.mark.parametrize(
836+
"sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))]
837+
)
838+
def test_isin_broadcast(self, sh_a, sh_b):
839+
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
840+
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
852841

853-
dp_a = dpnp.array(a)
854-
dp_b = dpnp.array(b)
842+
dp_a = dpnp.array(a)
843+
dp_b = dpnp.array(b)
855844

856-
expected = numpy.isin(a, b)
857-
result = dpnp.isin(dp_a, dp_b)
858-
assert_equal(result, expected)
845+
expected = numpy.isin(a, b)
846+
result = dpnp.isin(dp_a, dp_b)
847+
assert_equal(result, expected)
859848

849+
def test_isin_scalar_elements(self):
850+
a = numpy.array([1, 2, 3])
851+
b = 2
860852

861-
def test_isin_scalar_test_elements():
862-
a = 2
863-
b = numpy.array([1, 2, 3])
853+
dp_a = dpnp.array(a)
854+
dp_b = dpnp.array(b)
864855

865-
dp_a = dpnp.array(a)
866-
dp_b = dpnp.array(b)
856+
expected = numpy.isin(a, b)
857+
result = dpnp.isin(dp_a, dp_b)
858+
assert_equal(result, expected)
867859

868-
expected = numpy.isin(a, b)
869-
result = dpnp.isin(dp_a, dp_b)
870-
assert_equal(result, expected)
860+
def test_isin_scalar_test_elements(self):
861+
a = 2
862+
b = numpy.array([1, 2, 3])
871863

864+
dp_a = dpnp.array(a)
865+
dp_b = dpnp.array(b)
872866

873-
def test_isin_empty():
874-
a = numpy.array([], dtype=int)
875-
b = numpy.array([1, 2, 3])
867+
expected = numpy.isin(a, b)
868+
result = dpnp.isin(dp_a, dp_b)
869+
assert_equal(result, expected)
876870

877-
dp_a = dpnp.array(a)
878-
dp_b = dpnp.array(b)
871+
def test_isin_empty(self):
872+
a = numpy.array([], dtype=int)
873+
b = numpy.array([1, 2, 3])
879874

880-
expected = numpy.isin(a, b)
881-
result = dpnp.isin(dp_a, dp_b)
882-
assert_equal(result, expected)
875+
dp_a = dpnp.array(a)
876+
dp_b = dpnp.array(b)
883877

878+
expected = numpy.isin(a, b)
879+
result = dpnp.isin(dp_a, dp_b)
880+
assert_equal(result, expected)
884881

885-
def test_isin_errors():
886-
a = dpnp.arange(5)
887-
b = dpnp.arange(3)
882+
def test_isin_errors(self):
883+
a = dpnp.arange(5)
884+
b = dpnp.arange(3)
888885

889-
# unsupported type for elements or test_elements
890-
with pytest.raises(TypeError):
891-
dpnp.isin(dict(), b)
886+
# unsupported type for elements or test_elements
887+
with pytest.raises(TypeError):
888+
dpnp.isin(dict(), b)
892889

893-
with pytest.raises(TypeError):
894-
dpnp.isin(a, dict())
890+
with pytest.raises(TypeError):
891+
dpnp.isin(a, dict())

0 commit comments

Comments
 (0)