Skip to content

Commit 66c5d89

Browse files
authored
Merge pull request #57 from forFudan/zyharray
[lib] Implement the `argsort` functionality
2 parents 2d3d29b + 9a74fa0 commit 66c5d89

File tree

2 files changed

+115
-0
lines changed

2 files changed

+115
-0
lines changed

numojo/core/sort.mojo

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,99 @@ fn binary_sort[
155155
result[i - 1] = result[i]
156156
result[i] = temp
157157
return result
158+
159+
160+
# ===------------------------------------------------------------------------===#
161+
# Argsort using quick sort algorithm
162+
# ===------------------------------------------------------------------------===#
163+
164+
165+
fn _argsort_partition(
166+
inout ndarray: NDArray,
167+
inout idx_array: NDArray,
168+
left: Int,
169+
right: Int,
170+
pivot_index: Int,
171+
) raises -> Int:
172+
var pivot_value = ndarray[pivot_index]
173+
ndarray[pivot_index], ndarray[right] = ndarray[right], ndarray[pivot_index]
174+
idx_array[pivot_index], idx_array[right] = (
175+
idx_array[right],
176+
idx_array[pivot_index],
177+
)
178+
var store_index = left
179+
180+
for i in range(left, right):
181+
if ndarray[i] < pivot_value:
182+
ndarray[store_index], ndarray[i] = ndarray[i], ndarray[store_index]
183+
idx_array[store_index], idx_array[i] = (
184+
idx_array[i],
185+
idx_array[store_index],
186+
)
187+
store_index = store_index + 1
188+
189+
ndarray[right], ndarray[store_index] = ndarray[store_index], ndarray[right]
190+
idx_array[right], idx_array[store_index] = (
191+
idx_array[store_index],
192+
idx_array[right],
193+
)
194+
195+
return store_index
196+
197+
198+
fn argsort_inplace[
199+
dtype: DType
200+
](
201+
inout ndarray: NDArray[dtype],
202+
inout idx_array: NDArray[DType.index],
203+
left: Int,
204+
right: Int,
205+
) raises:
206+
"""
207+
Conduct Argsort (in-place) based on the NDArray using quick sort.
208+
209+
Parameters:
210+
dtype: The input element type.
211+
212+
Args:
213+
ndarray: An NDArray.
214+
idx_array: An NDArray of the indices.
215+
left: Left index of the partition.
216+
right: Right index of the partition.
217+
"""
218+
219+
if right > left:
220+
var pivot_index = left + (right - left) // 2
221+
var pivot_new_index = _argsort_partition(
222+
ndarray, idx_array, left, right, pivot_index
223+
)
224+
argsort_inplace(ndarray, idx_array, left, pivot_new_index - 1)
225+
argsort_inplace(ndarray, idx_array, pivot_new_index + 1, right)
226+
227+
228+
fn argsort[
229+
dtype: DType
230+
](ndarray: NDArray[dtype],) raises -> NDArray[DType.index]:
231+
"""
232+
Argsort of the NDArray using quick sort algorithm.
233+
234+
Parameters:
235+
dtype: The input element type.
236+
237+
Args:
238+
ndarray: An NDArray.
239+
240+
Returns:
241+
The indices of the sorted NDArray.
242+
"""
243+
244+
var array: NDArray[dtype] = ndarray
245+
var length = array.size()
246+
247+
var idx_array = NDArray[DType.index](length)
248+
for i in range(length):
249+
idx_array[i] = i
250+
251+
argsort_inplace(array, idx_array, 0, length - 1)
252+
253+
return idx_array

tests/argsort.mojo

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Test file for
2+
# numojo.core.sort.argsort()
3+
4+
import numojo as nm
5+
from numojo.core.ndarray import NDArray
6+
import time
7+
8+
fn main() raises:
9+
test[nm.f64](6)
10+
test[nm.i32](12)
11+
test[nm.f64](1000)
12+
13+
fn test[dtype: DType](length: Int) raises:
14+
# Initialize an ND arrays of type
15+
var t0 = time.now()
16+
var A = NDArray[dtype](length, random=True)
17+
print(A)
18+
print(nm.core.sort.argsort(A))
19+
print((time.now() - t0)/1e9, "s")

0 commit comments

Comments
 (0)