@@ -30,7 +30,10 @@ def matmul(x1: Array, x2: Array, /) -> Array:
3030 if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
3131 raise TypeError ('Only numeric dtypes are allowed in matmul' )
3232
33- return Array ._new (np .matmul (x1 ._array , x2 ._array ))
33+ if x1 .device != x2 .device :
34+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
35+
36+ return Array ._new (np .matmul (x1 ._array , x2 ._array ), device = x1 .device )
3437
3538# Note: tensordot is the numpy top-level namespace but not in np.linalg
3639
@@ -41,14 +44,17 @@ def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int],
4144 if x1 .dtype not in _numeric_dtypes or x2 .dtype not in _numeric_dtypes :
4245 raise TypeError ('Only numeric dtypes are allowed in tensordot' )
4346
44- return Array ._new (np .tensordot (x1 ._array , x2 ._array , axes = axes ))
47+ if x1 .device != x2 .device :
48+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
49+
50+ return Array ._new (np .tensordot (x1 ._array , x2 ._array , axes = axes ), device = x1 .device )
4551
4652# Note: this function is new in the array API spec. Unlike transpose, it only
4753# transposes the last two axes.
4854def matrix_transpose (x : Array , / ) -> Array :
4955 if x .ndim < 2 :
5056 raise ValueError ("x must be at least 2-dimensional for matrix_transpose" )
51- return Array ._new (np .swapaxes (x ._array , - 1 , - 2 ))
57+ return Array ._new (np .swapaxes (x ._array , - 1 , - 2 ), device = x . device )
5258
5359# Note: vecdot is not in NumPy
5460def vecdot (x1 : Array , x2 : Array , / , * , axis : int = - 1 ) -> Array :
@@ -61,6 +67,9 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
6167 elif axis < min (- 1 , - x1 .ndim , - x2 .ndim ):
6268 raise ValueError ("axis is out of bounds for x1 and x2" )
6369
70+ if x1 .device != x2 .device :
71+ raise RuntimeError (f"Arrays from two different devices ({ x1 .device } and { x2 .device } ) can not be combined." )
72+
6473 # In versions of the standard prior to 2023.12, vecdot applied axis after
6574 # broadcasting. This is different from applying it before broadcasting
6675 # when axis is nonnegative. The below code keeps this behavior for
@@ -78,4 +87,4 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
7887 x2_ = np .moveaxis (x2_ , axis , - 1 )
7988
8089 res = x1_ [..., None , :] @ x2_ [..., None ]
81- return Array ._new (res [..., 0 , 0 ])
90+ return Array ._new (res [..., 0 , 0 ], device = x1 . device )
0 commit comments