@@ -155,3 +155,99 @@ fn binary_sort[
155155 result[i - 1 ] = result[i]
156156 result[i] = temp
157157 return result
158+
159+
160+ # ===------------------------------------------------------------------------===#
161+ # Argsort using quick sort algorithm
162+ # ===------------------------------------------------------------------------===#
163+
164+
165+ fn _argsort_partition (
166+ inout ndarray : NDArray,
167+ inout idx_array : NDArray,
168+ left : Int,
169+ right : Int,
170+ pivot_index : Int,
171+ ) raises -> Int:
172+ var pivot_value = ndarray[pivot_index]
173+ ndarray[pivot_index], ndarray[right] = ndarray[right], ndarray[pivot_index]
174+ idx_array[pivot_index], idx_array[right] = (
175+ idx_array[right],
176+ idx_array[pivot_index],
177+ )
178+ var store_index = left
179+
180+ for i in range (left, right):
181+ if ndarray[i] < pivot_value:
182+ ndarray[store_index], ndarray[i] = ndarray[i], ndarray[store_index]
183+ idx_array[store_index], idx_array[i] = (
184+ idx_array[i],
185+ idx_array[store_index],
186+ )
187+ store_index = store_index + 1
188+
189+ ndarray[right], ndarray[store_index] = ndarray[store_index], ndarray[right]
190+ idx_array[right], idx_array[store_index] = (
191+ idx_array[store_index],
192+ idx_array[right],
193+ )
194+
195+ return store_index
196+
197+
198+ fn argsort_inplace [
199+ dtype : DType
200+ ](
201+ inout ndarray : NDArray[dtype],
202+ inout idx_array : NDArray[DType.index],
203+ left : Int,
204+ right : Int,
205+ ) raises :
206+ """
207+ Conduct Argsort (in-place) based on the NDArray using quick sort.
208+
209+ Parameters:
210+ dtype: The input element type.
211+
212+ Args:
213+ ndarray: An NDArray.
214+ idx_array: An NDArray of the indices.
215+ left: Left index of the partition.
216+ right: Right index of the partition.
217+ """
218+
219+ if right > left:
220+ var pivot_index = left + (right - left) // 2
221+ var pivot_new_index = _argsort_partition(
222+ ndarray, idx_array, left, right, pivot_index
223+ )
224+ argsort_inplace(ndarray, idx_array, left, pivot_new_index - 1 )
225+ argsort_inplace(ndarray, idx_array, pivot_new_index + 1 , right)
226+
227+
228+ fn argsort [
229+ dtype : DType
230+ ](ndarray : NDArray[dtype],) raises -> NDArray[DType.index]:
231+ """
232+ Argsort of the NDArray using quick sort algorithm.
233+
234+ Parameters:
235+ dtype: The input element type.
236+
237+ Args:
238+ ndarray: An NDArray.
239+
240+ Returns:
241+ The indices of the sorted NDArray.
242+ """
243+
244+ var array : NDArray[dtype] = ndarray
245+ var length = array.size()
246+
247+ var idx_array = NDArray[DType.index](length)
248+ for i in range (length):
249+ idx_array[i] = i
250+
251+ argsort_inplace(array, idx_array, 0 , length - 1 )
252+
253+ return idx_array
0 commit comments