@@ -25,8 +25,11 @@ def concat(
2525 # Note: Casting rules here are different from the np.concatenate default
2626 # (no for scalars with axis=None, no cross-kind casting)
2727 dtype = result_type (* arrays )
28+ if len ({a .device for a in arrays }) > 1 :
29+ raise ValueError ("concat inputs must all be on the same device" )
30+
2831 arrays = tuple (a ._array for a in arrays )
29- return Array ._new (np .concatenate (arrays , axis = axis , dtype = dtype ._np_dtype ))
32+ return Array ._new (np .concatenate (arrays , axis = axis , dtype = dtype ._np_dtype ), device = arrays [ 0 ]. device )
3033
3134
3235def expand_dims (x : Array , / , * , axis : int ) -> Array :
@@ -35,7 +38,7 @@ def expand_dims(x: Array, /, *, axis: int) -> Array:
3538
3639 See its docstring for more information.
3740 """
38- return Array ._new (np .expand_dims (x ._array , axis ))
41+ return Array ._new (np .expand_dims (x ._array , axis ), device = x . device )
3942
4043
4144def flip (x : Array , / , * , axis : Optional [Union [int , Tuple [int , ...]]] = None ) -> Array :
@@ -44,7 +47,7 @@ def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) ->
4447
4548 See its docstring for more information.
4649 """
47- return Array ._new (np .flip (x ._array , axis = axis ))
50+ return Array ._new (np .flip (x ._array , axis = axis ), device = x . device )
4851
4952@requires_api_version ('2023.12' )
5053def moveaxis (
@@ -58,7 +61,7 @@ def moveaxis(
5861
5962 See its docstring for more information.
6063 """
61- return Array ._new (np .moveaxis (x ._array , source , destination ))
64+ return Array ._new (np .moveaxis (x ._array , source , destination ), device = x . device )
6265
6366# Note: The function name is different here (see also matrix_transpose).
6467# Unlike transpose(), the axes argument is required.
@@ -68,7 +71,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
6871
6972 See its docstring for more information.
7073 """
71- return Array ._new (np .transpose (x ._array , axes ))
74+ return Array ._new (np .transpose (x ._array , axes ), device = x . device )
7275
7376@requires_api_version ('2023.12' )
7477def repeat (
@@ -94,7 +97,7 @@ def repeat(
9497 else :
9598 raise TypeError ("repeats must be an int or array" )
9699
97- return Array ._new (np .repeat (x ._array , repeats , axis = axis ))
100+ return Array ._new (np .repeat (x ._array , repeats , axis = axis ), device = x . device )
98101
99102# Note: the optional argument is called 'shape', not 'newshape'
100103def reshape (x : Array ,
@@ -117,7 +120,7 @@ def reshape(x: Array,
117120 if copy is False and not np .shares_memory (data , reshaped ):
118121 raise AttributeError ("Incompatible shape for in-place modification." )
119122
120- return Array ._new (reshaped )
123+ return Array ._new (reshaped , device = x . device )
121124
122125
123126def roll (
@@ -132,7 +135,7 @@ def roll(
132135
133136 See its docstring for more information.
134137 """
135- return Array ._new (np .roll (x ._array , shift , axis = axis ))
138+ return Array ._new (np .roll (x ._array , shift , axis = axis ), device = x . device )
136139
137140
138141def squeeze (x : Array , / , axis : Union [int , Tuple [int , ...]]) -> Array :
@@ -141,7 +144,7 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array:
141144
142145 See its docstring for more information.
143146 """
144- return Array ._new (np .squeeze (x ._array , axis = axis ))
147+ return Array ._new (np .squeeze (x ._array , axis = axis ), device = x . device )
145148
146149
147150def stack (arrays : Union [Tuple [Array , ...], List [Array ]], / , * , axis : int = 0 ) -> Array :
@@ -152,8 +155,10 @@ def stack(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: int = 0) ->
152155 """
153156 # Call result type here just to raise on disallowed type combinations
154157 result_type (* arrays )
158+ if len ({a .device for a in arrays }) > 1 :
159+ raise ValueError ("concat inputs must all be on the same device" )
155160 arrays = tuple (a ._array for a in arrays )
156- return Array ._new (np .stack (arrays , axis = axis ))
161+ return Array ._new (np .stack (arrays , axis = axis ), device = arrays [ 0 ]. device )
157162
158163
159164@requires_api_version ('2023.12' )
@@ -166,7 +171,7 @@ def tile(x: Array, repetitions: Tuple[int, ...], /) -> Array:
166171 # Note: NumPy allows repetitions to be an int or array
167172 if not isinstance (repetitions , tuple ):
168173 raise TypeError ("repetitions must be a tuple" )
169- return Array ._new (np .tile (x ._array , repetitions ))
174+ return Array ._new (np .tile (x ._array , repetitions ), device = x . device )
170175
171176# Note: this function is new
172177@requires_api_version ('2023.12' )
0 commit comments