11"""
22# ===----------------------------------------------------------------------=== #
33# ARRAY MANIPULATION ROUTINES
4- # Last updated: 2024-07-21
4+ # Last updated: 2024-08-03
55# ===----------------------------------------------------------------------=== #
66"""
77
88
99fn copyto ():
1010 pass
1111
12+
1213fn ndim [dtype : DType](array : NDArray[dtype]) -> Int:
1314 """
1415 Returns the number of dimensions of the NDArray.
@@ -34,6 +35,7 @@ fn shape[dtype: DType](array: NDArray[dtype]) -> NDArrayShape:
3435 """
3536 return array.ndshape
3637
38+
3739fn size [dtype : DType](array : NDArray[dtype], axis : Int) raises -> Int:
3840 """
3941 Returns the size of the NDArray.
@@ -47,7 +49,12 @@ fn size[dtype: DType](array: NDArray[dtype], axis: Int) raises -> Int:
4749 """
4850 return array.ndshape[axis]
4951
50- fn reshape [dtype : DType](inout array : NDArray[dtype], shape : VariadicList[Int], order : String = " C" ) raises :
52+
53+ fn reshape [
54+ dtype : DType
55+ ](
56+ inout array : NDArray[dtype], shape : VariadicList[Int], order : String = " C"
57+ ) raises :
5158 """
5259 Reshapes the NDArray to given Shape.
5360
@@ -58,7 +65,7 @@ fn reshape[dtype: DType](inout array: NDArray[dtype], shape: VariadicList[Int],
5865 array: A NDArray.
5966 shape: Variadic integers of shape.
6067 order: Order of the array - Row major `C` or Column major `F`.
61-
68+
6269 """
6370 var num_elements_new : Int = 1
6471 var ndim_new : Int = 0
@@ -81,26 +88,32 @@ fn reshape[dtype: DType](inout array: NDArray[dtype], shape: VariadicList[Int],
8188 array.stride = NDArrayStride(shape = shape_new, order = order)
8289 array.order = order
8390
91+
8492fn ravel [dtype : DType](inout array : NDArray[dtype], order : String = " C" ) raises :
8593 """
8694 Returns the raveled version of the NDArray.
8795 """
8896 if array.ndim == 1 :
8997 print (" Array is already 1D" )
90- return
98+ return
9199 else :
92100 if order == " C" :
93101 reshape[dtype](array, array.ndshape._size, order = " C" )
94102 else :
95103 reshape[dtype](array, array.ndshape._size, order = " F" )
96104
97- fn where [dtype : DType](inout x : NDArray[dtype], scalar : SIMD [dtype, 1 ], mask :NDArray[DType.bool]) raises :
105+
106+ fn where [
107+ dtype : DType
108+ ](
109+ inout x : NDArray[dtype], scalar : SIMD [dtype, 1 ], mask : NDArray[DType.bool]
110+ ) raises :
98111 """
99112 Replaces elements in `x` with `scalar` where `mask` is True.
100113
101114 Parameters:
102115 dtype: DType.
103-
116+
104117 Args:
105118 x: A NDArray.
106119 scalar: A SIMD value.
@@ -111,8 +124,11 @@ fn where[dtype: DType](inout x: NDArray[dtype], scalar: SIMD[dtype, 1], mask:NDA
111124 if mask.data[i] == True :
112125 x.data.store(i, scalar)
113126
127+
114128# TODO : do it with vectorization
115- fn where [dtype : DType](inout x : NDArray[dtype], y : NDArray[dtype], mask :NDArray[DType.bool]) raises :
129+ fn where [
130+ dtype : DType
131+ ](inout x : NDArray[dtype], y : NDArray[dtype], mask : NDArray[DType.bool]) raises :
116132 """
117133 Replaces elements in `x` with elements from `y` where `mask` is True.
118134
@@ -121,7 +137,7 @@ fn where[dtype: DType](inout x: NDArray[dtype], y: NDArray[dtype], mask:NDArray[
121137
122138 Parameters:
123139 dtype: DType.
124-
140+
125141 Args:
126142 x: NDArray[dtype].
127143 y: NDArray[dtype].
@@ -135,13 +151,13 @@ fn where[dtype: DType](inout x: NDArray[dtype], y: NDArray[dtype], mask:NDArray[
135151 x.data.store(i, y.data[i])
136152
137153
138- fn flip [dtype : DType](inout array : NDArray[dtype]) raises -> NDArray[dtype]:
154+ fn flip [dtype : DType](array : NDArray[dtype]) raises -> NDArray[dtype]:
139155 """
140156 Flips the NDArray along the given axis.
141157
142158 Parameters:
143159 dtype: DType.
144-
160+
145161 Args:
146162 array: A NDArray.
147163
@@ -151,8 +167,9 @@ fn flip[dtype: DType](inout array: NDArray[dtype]) raises -> NDArray[dtype]:
151167 if array.ndim != 1 :
152168 raise Error(" Flip is only supported for 1D arrays" )
153169
154- var result : NDArray[dtype] = NDArray[dtype](shape = array.ndshape, order = array.order)
170+ var result : NDArray[dtype] = NDArray[dtype](
171+ shape = array.ndshape, order = array.order
172+ )
155173 for i in range (array.ndshape._size):
156174 result.data.store(i, array.data[array.ndshape._size - i - 1 ])
157-
158175 return result
0 commit comments