66import math
77from algorithm import vectorize
88
9- from ..core.ndarray import NDArray, NDArrayShape
9+ from numojo.core.ndarray import NDArray
10+ from numojo.core.ndshape import NDArrayShape
11+ from numojo.routines.manipulation import flatten, transpose
1012
1113"""
1214TODO:
@@ -57,16 +59,25 @@ fn bubble_sort[dtype: DType](ndarray: NDArray[dtype]) raises -> NDArray[dtype]:
5759 return result
5860
5961
60- # Quick sort
62+ # #############
63+ # Quick sort #
64+ # #############
6165
6266
63- fn _partition (
64- mut ndarray : NDArray, left : Int, right : Int, pivot_index : Int
67+ fn _partition_in_range (
68+ mut A : NDArray,
69+ mut I : NDArray,
70+ left : Int,
71+ right : Int,
72+ pivot_index : Int,
6573) raises -> Int:
66- """ Do partition for the data buffer of ndarray.
74+ """
75+ Do in-place partition for array buffer within given range.
76+ Auxiliary function for `sort`, `argsort`, and `partition`.
6777
6878 Args:
69- ndarray: An NDArray.
79+ A: NDArray.
80+ I: NDArray used to store indices.
7081 left: Left index of the partition.
7182 right: Right index of the partition.
7283 pivot_index: Input pivot index
@@ -75,81 +86,170 @@ fn _partition(
7586 New pivot index.
7687 """
7788
78- var pivot_value = ndarray.get(pivot_index)
79- var _value_at_pivot = ndarray.get(pivot_index)
80- ndarray.set(pivot_index, ndarray.get(right))
81- ndarray.set(right, _value_at_pivot)
89+ # (Unsafe) Boundary checks are not done for sake of speed:
90+ # if (left >= A.size) or (right >= A.size) or (pivot_index >= A.size):
91+
92+ var pivot_value = A._buf[pivot_index]
93+
94+ A._buf[pivot_index], A._buf[right] = A._buf[right], A._buf[pivot_index]
95+ I._buf[pivot_index], I._buf[right] = I._buf[right], I._buf[pivot_index]
8296
8397 var store_index = left
8498
8599 for i in range (left, right):
86- if ndarray.get(i) < pivot_value:
87- var _value_at_store = ndarray.get(store_index)
88- ndarray.set(store_index, ndarray.get(i))
89- ndarray.set(i, _value_at_store)
100+ if A._buf[i] < pivot_value:
101+ A._buf[store_index], A._buf[i] = A._buf[i], A._buf[store_index]
102+ I._buf[store_index], I._buf[i] = I._buf[i], I._buf[store_index]
90103 store_index = store_index + 1
91104
92- var _value_at_store = ndarray.get(store_index)
93- ndarray.set(store_index, ndarray.get(right))
94- ndarray.set(right, _value_at_store)
105+ A._buf[store_index], A._buf[right] = A._buf[right], A._buf[store_index]
106+ I._buf[store_index], I._buf[right] = I._buf[right], I._buf[store_index]
95107
96108 return store_index
97109
98110
99- fn quick_sort_inplace [
100- dtype : DType
101- ](mut ndarray : NDArray[dtype], left : Int, right : Int,) raises :
111+ fn _sort_in_range (mut A : NDArray, mut I : NDArray, left : Int, right : Int) raises :
102112 """
103- Quick sort (in-place) the NDArray.
104-
105- Parameters:
106- dtype: The input element type.
113+ Sort in-place of the data buffer (quick-sort) within give range.
114+ It is not guaranteed to be stable.
107115
108116 Args:
109- ndarray: An NDArray.
117+ A: NDArray.
118+ I: NDArray used to store indices.
110119 left: Left index of the partition.
111120 right: Right index of the partition.
112121 """
113122
114123 if right > left:
115124 var pivot_index = left + (right - left) // 2
116- var pivot_new_index = _partition(ndarray, left, right, pivot_index)
117- quick_sort_inplace(ndarray, left, pivot_new_index - 1 )
118- quick_sort_inplace(ndarray, pivot_new_index + 1 , right)
125+ var pivot_new_index = _partition_in_range(
126+ A, I, left, right, pivot_index
127+ )
128+ _sort_in_range(A, I, left, pivot_new_index - 1 )
129+ _sort_in_range(A, I, pivot_new_index + 1 , right)
119130
120131
121- fn quick_sort [dtype : DType](ndarray : NDArray[dtype]) raises -> NDArray[dtype]:
132+ fn _sort_inplace [
133+ dtype : DType
134+ ](mut A : NDArray[dtype], mut I : NDArray[DType.index]) raises :
122135 """
123- Quick sort the NDArray.
124- Adopt in-place partition.
125- Average complexity: O(nlogn).
126- Worst-case complexity: O(n^2).
127- Worst-case space complexity: O(n).
128- Unstable.
136+ Sort in-place NDArray using quick sort method.
137+ It is not guaranteed to be unstable.
129138
130- Example:
131- ```py
132- var arr = numojo.core.random.rand[numojo.i16](100)
133- var sorted_arr = numojo.core.sort.quick_sort(arr)
134- print(sorted_arr)
135- ```
139+ When no axis is given, the array is flattened before sorting.
136140
137141 Parameters:
138142 dtype: The input element type.
139143
140144 Args:
141- ndarray: An NDArray.
145+ A: NDArray.
146+ I: NDArray that stores the indices.
147+ """
148+
149+ A = flatten(A)
150+ _sort_in_range(A, I, 0 , A.size - 1 )
142151
152+
153+ fn _sort_inplace [
154+ dtype : DType
155+ ](mut A : NDArray[dtype], mut I : NDArray[DType.index], owned axis : Int) raises :
143156 """
157+ Sort in-place NDArray along the given axis using quick sort method.
158+ It is not guaranteed to be unstable.
144159
145- var result : NDArray[dtype] = ndarray
146- var length = ndarray.size
147- quick_sort_inplace(result, 0 , length - 1 )
160+ When no axis is given, the array is flattened before sorting.
148161
149- return result
162+ Parameters:
163+ dtype: The input element type.
164+
165+ Args:
166+ A: NDArray to sort.
167+ I: NDArray that stores the indices.
168+ axis: The axis along which the array is sorted.
169+
170+ """
171+
172+ if axis < 0 :
173+ axis = A.ndim + axis
174+
175+ if (axis >= A.ndim) or (axis < 0 ):
176+ raise Error(
177+ String(" Axis {} is invalid for array of {} dimensions" ).format(
178+ axis, A.ndim
179+ )
180+ )
181+
182+ var continous_axis = A.ndim - 1 if A.order == " C" else A.ndim - 2
183+ """ Continuously stored axis. -1 if row-major, -2 if col-major."""
184+
185+ if axis == continous_axis: # Last axis
186+ var I = zeros[DType.index](shape = A.shape)
187+ for i in range (A.size // A.shape[continous_axis]):
188+ _sort_in_range(
189+ A,
190+ I,
191+ left = i * A.shape[continous_axis],
192+ right = (i + 1 ) * A.shape[continous_axis] - 1 ,
193+ )
194+ else :
195+ var transposed_axes = List[Int](capacity = A.ndim)
196+ for i in range (A.ndim):
197+ transposed_axes.append(i)
198+ transposed_axes[axis], transposed_axes[continous_axis] = (
199+ transposed_axes[continous_axis],
200+ transposed_axes[axis],
201+ )
202+ A = transpose(A, axes = transposed_axes)
203+ _sort_inplace(A, I, axis = - 1 )
204+ A = transpose(A, axes = transposed_axes)
205+
206+
207+ fn sort [dtype : DType](owned A : NDArray[dtype]) raises -> NDArray[dtype]:
208+ """
209+ Sort NDArray using quick sort method.
210+ It is not guaranteed to be unstable.
211+
212+ When no axis is given, the array is flattened before sorting.
213+
214+ Parameters:
215+ dtype: The input element type.
150216
217+ Args:
218+ A: NDArray.
219+ """
220+
221+ var I = NDArray[DType.index](A.shape)
222+ A = flatten(A)
223+ _sort_inplace(A, I)
224+ return A^
225+
226+
227+ fn sort [
228+ dtype : DType
229+ ](owned A : NDArray[dtype], owned axis : Int) raises -> NDArray[dtype]:
230+ """
231+ Sort NDArray along the given axis using quick sort method.
232+ It is not guaranteed to be unstable.
151233
152- # Binary sort
234+ When no axis is given, the array is flattened before sorting.
235+
236+ Parameters:
237+ dtype: The input element type.
238+
239+ Args:
240+ A: NDArray to sort.
241+ axis: The axis along which the array is sorted.
242+
243+ """
244+
245+ var I = NDArray[DType.index](A.shape)
246+ _sort_inplace(A, I, axis)
247+ return A^
248+
249+
250+ # ##############
251+ # Binary sort #
252+ # ##############
153253
154254
155255fn binary_sort [
@@ -193,7 +293,9 @@ fn binary_sort[
193293 return result
194294
195295
196- # Argsort using quick sort algorithm
296+ # ===----------------------------------------------------------------------=== #
297+ # Searching
298+ # ===----------------------------------------------------------------------=== #
197299
198300
199301fn _argsort_partition (
0 commit comments