We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0de0fa9 commit c70a886Copy full SHA for c70a886
pytensor/tensor/basic.py
@@ -2943,6 +2943,8 @@ def stack(tensors: Sequence["TensorLike"], axis: int = 0):
2943
):
2944
# In case there is direct scalar
2945
tensors = list(map(as_tensor_variable, tensors))
2946
+ if len(tensors) == 1:
2947
+ return atleast_1d(tensors[0])
2948
dtype = ps.upcast(*[i.dtype for i in tensors])
2949
return MakeVector(dtype)(*tensors)
2950
return join(axis, *[shape_padaxis(t, axis) for t in tensors])
0 commit comments