Skip to content

Commit e39211f

Browse files
authored
[update] sort function to allow sorting by axis > 2 for high dim arrays (#154)
This PR updates the `sort` function to allow: (1) Sorting by axis > 2 for high dimensional arrays. (2) Sorting row-major or col-major arrays. (3) Flattening the array before sorting if no axis is given.
1 parent 4bac711 commit e39211f

File tree

6 files changed

+210
-83
lines changed

6 files changed

+210
-83
lines changed

numojo/__init__.mojo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ from numojo.routines.manipulation import (
171171
from numojo.routines import random
172172

173173
from numojo.routines import sorting
174+
from numojo.routines.sorting import sort
174175

175176
from numojo.routines import searching
176177
from numojo.routines.searching import argmax, argmin

numojo/core/ndarray.mojo

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ from memory import memset_zero, memcpy
3030

3131

3232
import numojo.core._array_funcs as _af
33-
import numojo.routines.sorting as sort
33+
import numojo.routines.sorting as sorting
3434
import numojo.routines.math.arithmetic as arithmetic
3535
import numojo.routines.logic.comparison as comparison
3636
import numojo.routines.math.rounding as rounding
@@ -2090,13 +2090,13 @@ struct NDArray[dtype: DType = DType.float64](
20902090
"""
20912091
Sort the NDArray and return the sorted indices.
20922092
2093-
See `numojo.core.sort.argsort()`.
2093+
See `numojo.routines.sorting.argsort()`.
20942094
20952095
Returns:
20962096
The indices of the sorted NDArray.
20972097
"""
20982098

2099-
return sort.argsort(self)
2099+
return sorting.argsort(self)
21002100

21012101
fn astype[type: DType](self) raises -> NDArray[type]:
21022102
"""
@@ -2508,9 +2508,31 @@ struct NDArray[dtype: DType = DType.float64](
25082508

25092509
fn sort(mut self) raises:
25102510
"""
2511-
Sort the array inplace using quickstort.
2511+
Sort NDArray using quick sort method.
2512+
It is not guaranteed to be stable.
2513+
2514+
When no axis is given, the array is flattened before sorting.
2515+
2516+
See `numojo.sorting.sort` for more information.
2517+
"""
2518+
var I = NDArray[DType.index](self.shape)
2519+
self = flatten(self)
2520+
sorting._sort_inplace(
2521+
self,
2522+
I,
2523+
)
2524+
2525+
fn sort(mut self, owned axis: Int) raises:
2526+
"""
2527+
Sort NDArray along the given axis using quick sort method.
2528+
It is not guaranteed to be stable.
2529+
2530+
When no axis is given, the array is flattened before sorting.
2531+
2532+
See `numojo.sorting.sort` for more information.
25122533
"""
2513-
sort.quick_sort_inplace[dtype](self, 0, self.size - 1)
2534+
var I = NDArray[DType.index](self.shape)
2535+
sorting._sort_inplace(self, I, axis=axis)
25142536

25152537
fn sum(self: Self, axis: Int) raises -> Self:
25162538
"""

numojo/routines/manipulation.mojo

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ fn transpose[
203203
"""
204204
if len(axes) != A.ndim:
205205
raise Error(
206-
String("Length of axes {} does not match ndim of A {}").format(
207-
len(axes), A.ndim
208-
)
206+
String(
207+
"Length of `axes` ({}) does not match `ndim` of array ({})"
208+
).format(len(axes), A.ndim)
209209
)
210210

211211
for i in range(A.ndim):

numojo/routines/sorting.mojo

Lines changed: 150 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import math
77
from algorithm import vectorize
88

9-
from ..core.ndarray import NDArray, NDArrayShape
9+
from numojo.core.ndarray import NDArray
10+
from numojo.core.ndshape import NDArrayShape
11+
from numojo.routines.manipulation import flatten, transpose
1012

1113
"""
1214
TODO:
@@ -57,16 +59,25 @@ fn bubble_sort[dtype: DType](ndarray: NDArray[dtype]) raises -> NDArray[dtype]:
5759
return result
5860

5961

60-
# Quick sort
62+
##############
63+
# Quick sort #
64+
##############
6165

6266

63-
fn _partition(
64-
mut ndarray: NDArray, left: Int, right: Int, pivot_index: Int
67+
fn _partition_in_range(
68+
mut A: NDArray,
69+
mut I: NDArray,
70+
left: Int,
71+
right: Int,
72+
pivot_index: Int,
6573
) raises -> Int:
66-
"""Do partition for the data buffer of ndarray.
74+
"""
75+
Do in-place partition for array buffer within given range.
76+
Auxiliary function for `sort`, `argsort`, and `partition`.
6777
6878
Args:
69-
ndarray: An NDArray.
79+
A: NDArray.
80+
I: NDArray used to store indices.
7081
left: Left index of the partition.
7182
right: Right index of the partition.
7283
pivot_index: Input pivot index
@@ -75,81 +86,170 @@ fn _partition(
7586
New pivot index.
7687
"""
7788

78-
var pivot_value = ndarray.get(pivot_index)
79-
var _value_at_pivot = ndarray.get(pivot_index)
80-
ndarray.set(pivot_index, ndarray.get(right))
81-
ndarray.set(right, _value_at_pivot)
89+
# (Unsafe) Boundary checks are not done for sake of speed:
90+
# if (left >= A.size) or (right >= A.size) or (pivot_index >= A.size):
91+
92+
var pivot_value = A._buf[pivot_index]
93+
94+
A._buf[pivot_index], A._buf[right] = A._buf[right], A._buf[pivot_index]
95+
I._buf[pivot_index], I._buf[right] = I._buf[right], I._buf[pivot_index]
8296

8397
var store_index = left
8498

8599
for i in range(left, right):
86-
if ndarray.get(i) < pivot_value:
87-
var _value_at_store = ndarray.get(store_index)
88-
ndarray.set(store_index, ndarray.get(i))
89-
ndarray.set(i, _value_at_store)
100+
if A._buf[i] < pivot_value:
101+
A._buf[store_index], A._buf[i] = A._buf[i], A._buf[store_index]
102+
I._buf[store_index], I._buf[i] = I._buf[i], I._buf[store_index]
90103
store_index = store_index + 1
91104

92-
var _value_at_store = ndarray.get(store_index)
93-
ndarray.set(store_index, ndarray.get(right))
94-
ndarray.set(right, _value_at_store)
105+
A._buf[store_index], A._buf[right] = A._buf[right], A._buf[store_index]
106+
I._buf[store_index], I._buf[right] = I._buf[right], I._buf[store_index]
95107

96108
return store_index
97109

98110

99-
fn quick_sort_inplace[
100-
dtype: DType
101-
](mut ndarray: NDArray[dtype], left: Int, right: Int,) raises:
111+
fn _sort_in_range(mut A: NDArray, mut I: NDArray, left: Int, right: Int) raises:
102112
"""
103-
Quick sort (in-place) the NDArray.
104-
105-
Parameters:
106-
dtype: The input element type.
113+
Sort the data buffer in place (quick sort) within the given range.
114+
It is not guaranteed to be stable.
107115
108116
Args:
109-
ndarray: An NDArray.
117+
A: NDArray.
118+
I: NDArray used to store indices.
110119
left: Left index of the partition.
111120
right: Right index of the partition.
112121
"""
113122

114123
if right > left:
115124
var pivot_index = left + (right - left) // 2
116-
var pivot_new_index = _partition(ndarray, left, right, pivot_index)
117-
quick_sort_inplace(ndarray, left, pivot_new_index - 1)
118-
quick_sort_inplace(ndarray, pivot_new_index + 1, right)
125+
var pivot_new_index = _partition_in_range(
126+
A, I, left, right, pivot_index
127+
)
128+
_sort_in_range(A, I, left, pivot_new_index - 1)
129+
_sort_in_range(A, I, pivot_new_index + 1, right)
119130

120131

121-
fn quick_sort[dtype: DType](ndarray: NDArray[dtype]) raises -> NDArray[dtype]:
132+
fn _sort_inplace[
133+
dtype: DType
134+
](mut A: NDArray[dtype], mut I: NDArray[DType.index]) raises:
122135
"""
123-
Quick sort the NDArray.
124-
Adopt in-place partition.
125-
Average complexity: O(nlogn).
126-
Worst-case complexity: O(n^2).
127-
Worst-case space complexity: O(n).
128-
Unstable.
136+
Sort in-place NDArray using quick sort method.
137+
It is not guaranteed to be stable.
129138
130-
Example:
131-
```py
132-
var arr = numojo.core.random.rand[numojo.i16](100)
133-
var sorted_arr = numojo.core.sort.quick_sort(arr)
134-
print(sorted_arr)
135-
```
139+
When no axis is given, the array is flattened before sorting.
136140
137141
Parameters:
138142
dtype: The input element type.
139143
140144
Args:
141-
ndarray: An NDArray.
145+
A: NDArray.
146+
I: NDArray that stores the indices.
147+
"""
148+
149+
A = flatten(A)
150+
_sort_in_range(A, I, 0, A.size - 1)
142151

152+
153+
fn _sort_inplace[
154+
dtype: DType
155+
](mut A: NDArray[dtype], mut I: NDArray[DType.index], owned axis: Int) raises:
143156
"""
157+
Sort in-place NDArray along the given axis using quick sort method.
158+
It is not guaranteed to be stable.
144159
145-
var result: NDArray[dtype] = ndarray
146-
var length = ndarray.size
147-
quick_sort_inplace(result, 0, length - 1)
160+
When no axis is given, the array is flattened before sorting.
148161
149-
return result
162+
Parameters:
163+
dtype: The input element type.
164+
165+
Args:
166+
A: NDArray to sort.
167+
I: NDArray that stores the indices.
168+
axis: The axis along which the array is sorted.
169+
170+
"""
171+
172+
if axis < 0:
173+
axis = A.ndim + axis
174+
175+
if (axis >= A.ndim) or (axis < 0):
176+
raise Error(
177+
String("Axis {} is invalid for array of {} dimensions").format(
178+
axis, A.ndim
179+
)
180+
)
181+
182+
var continous_axis = A.ndim - 1 if A.order == "C" else A.ndim - 2
183+
"""Continuously stored axis. -1 if row-major, -2 if col-major."""
184+
185+
if axis == continous_axis: # Last axis
186+
var I = zeros[DType.index](shape=A.shape)
187+
for i in range(A.size // A.shape[continous_axis]):
188+
_sort_in_range(
189+
A,
190+
I,
191+
left=i * A.shape[continous_axis],
192+
right=(i + 1) * A.shape[continous_axis] - 1,
193+
)
194+
else:
195+
var transposed_axes = List[Int](capacity=A.ndim)
196+
for i in range(A.ndim):
197+
transposed_axes.append(i)
198+
transposed_axes[axis], transposed_axes[continous_axis] = (
199+
transposed_axes[continous_axis],
200+
transposed_axes[axis],
201+
)
202+
A = transpose(A, axes=transposed_axes)
203+
_sort_inplace(A, I, axis=-1)
204+
A = transpose(A, axes=transposed_axes)
205+
206+
207+
fn sort[dtype: DType](owned A: NDArray[dtype]) raises -> NDArray[dtype]:
208+
"""
209+
Sort NDArray using quick sort method.
210+
It is not guaranteed to be stable.
211+
212+
When no axis is given, the array is flattened before sorting.
213+
214+
Parameters:
215+
dtype: The input element type.
150216
217+
Args:
218+
A: NDArray.
219+
"""
220+
221+
var I = NDArray[DType.index](A.shape)
222+
A = flatten(A)
223+
_sort_inplace(A, I)
224+
return A^
225+
226+
227+
fn sort[
228+
dtype: DType
229+
](owned A: NDArray[dtype], owned axis: Int) raises -> NDArray[dtype]:
230+
"""
231+
Sort NDArray along the given axis using quick sort method.
232+
It is not guaranteed to be stable.
151233
152-
# Binary sort
234+
When no axis is given, the array is flattened before sorting.
235+
236+
Parameters:
237+
dtype: The input element type.
238+
239+
Args:
240+
A: NDArray to sort.
241+
axis: The axis along which the array is sorted.
242+
243+
"""
244+
245+
var I = NDArray[DType.index](A.shape)
246+
_sort_inplace(A, I, axis)
247+
return A^
248+
249+
250+
###############
251+
# Binary sort #
252+
###############
153253

154254

155255
fn binary_sort[
@@ -193,7 +293,9 @@ fn binary_sort[
193293
return result
194294

195295

196-
# Argsort using quick sort algorithm
296+
# ===----------------------------------------------------------------------=== #
297+
# Searching
298+
# ===----------------------------------------------------------------------=== #
197299

198300

199301
fn _argsort_partition(

tests/test_sort.mojo

Lines changed: 0 additions & 27 deletions
This file was deleted.

0 commit comments

Comments
 (0)