Skip to content

Commit 2b17633

Browse files
authored
Merge pull request #54 from shivasankarka/ndarray
Minor fixes
2 parents 4366c7b + ce80a3f commit 2b17633

File tree

12 files changed

+270
-208
lines changed

12 files changed

+270
-208
lines changed

numojo/core/ndarray.mojo

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -521,14 +521,14 @@ struct _NDArrayIter[
521521
var length: Int
522522

523523
fn __init__(
524-
inout self,
525-
unsafe_pointer: DTypePointer[dtype],
524+
inout self,
525+
unsafe_pointer: DTypePointer[dtype],
526526
length: Int,
527527
):
528528
self.index = 0 if forward else length
529529
self.ptr = unsafe_pointer
530530
self.length = length
531-
531+
532532
fn __iter__(self) -> Self:
533533
return self
534534

@@ -550,12 +550,15 @@ struct _NDArrayIter[
550550
else:
551551
return self.index
552552

553+
553554
# ===----------------------------------------------------------------------===#
554555
# NDArray
555556
# ===----------------------------------------------------------------------===#
556557

557558

558-
struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Sized):
559+
struct NDArray[dtype: DType = DType.float32](
560+
Stringable, CollectionElement, Sized
561+
):
559562
"""The N-dimensional array (NDArray).
560563
561564
The array can be uniquely defined by the following:
@@ -927,7 +930,7 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
927930
for i in range(index.__len__()):
928931
if index[i] >= self.ndshape[i]:
929932
raise Error("Error: Elements of `index` exceed the array shape")
930-
var idx: Int = _get_index(index, self.coefficient)
933+
var idx: Int = _get_index(index, self.stride)
931934
return self.data.load[width=1](idx)
932935

933936
fn __getitem__(self, index: List[Int]) raises -> SIMD[dtype, 1]:
@@ -1007,6 +1010,8 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
10071010
var count: Int = 0
10081011
for i in range(slices.__len__()):
10091012
self._adjust_slice_(slices[i], self.ndshape[i])
1013+
if slices[i].start >= self.ndshape[i] or slices[i].end > self.ndshape[i]:
1014+
raise Error("Error: Slice value exceeds the array shape")
10101015
spec.append(slices[i].unsafe_indices())
10111016
if slices[i].unsafe_indices() != 1:
10121017
ndims += 1
@@ -1024,7 +1029,7 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
10241029
count = 0
10251030
for _ in range(ndims):
10261031
while spec[j] == 1:
1027-
count+=1
1032+
count += 1
10281033
j += 1
10291034
if j >= self.ndim:
10301035
break
@@ -1041,8 +1046,6 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
10411046
var noffset: Int = 0
10421047
if self.order == "C":
10431048
noffset = 0
1044-
if ndims == 1:
1045-
nstrides.append(1)
10461049
for i in range(ndims):
10471050
var temp_stride: Int = 1
10481051
for j in range(i + 1, ndims): # temp
@@ -1281,7 +1284,7 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
12811284
)
12821285

12831286
fn __reversed__(self) -> _NDArrayIter[dtype, forward=False]:
1284-
"""Iterate backwards over elements of the NDArray, returning
1287+
"""Iterate backwards over elements of the NDArray, returning
12851288
copied value.
12861289
12871290
Returns:
@@ -1519,41 +1522,45 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
15191522
# We might need to figure out how we want to handle truthyness before can do this
15201523
alias nelts: Int = simdwidthof[dtype]()
15211524
var result: Bool = True
1525+
15221526
@parameter
15231527
fn vectorized_all[simd_width: Int](idx: Int) -> None:
1524-
result = result and allb(self.data.load[width=simd_width](idx) )
1528+
result = result and allb(self.data.load[width=simd_width](idx))
1529+
15251530
vectorize[vectorized_all, nelts](self.ndshape._size)
1526-
return result
1531+
return result
15271532

15281533
fn any(self) raises -> Bool:
15291534
# make this a compile time check
15301535
if not (self.dtype == DType.bool or is_inttype(dtype)):
15311536
raise Error("Array elements must be Boolean or Integer.")
15321537
alias nelts: Int = simdwidthof[dtype]()
1533-
var result: Bool = False
1538+
var result: Bool = False
1539+
15341540
@parameter
15351541
fn vectorized_any[simd_width: Int](idx: Int) -> None:
1536-
result = result or anyb(self.data.load[width=simd_width](idx) )
1542+
result = result or anyb(self.data.load[width=simd_width](idx))
1543+
15371544
vectorize[vectorized_any, nelts](self.ndshape._size)
15381545
return result
15391546

15401547
fn argmax(self) -> Int:
15411548
var result: Int = 0
15421549
var max_val: SIMD[dtype, 1] = self.load[width=1](0)
15431550
for i in range(1, self.ndshape._size):
1544-
var temp: SIMD[dtype, 1] = self.load[width=1](i)
1545-
if temp > max_val:
1546-
max_val = temp
1551+
var temp: SIMD[dtype, 1] = self.load[width=1](i)
1552+
if temp > max_val:
1553+
max_val = temp
15471554
result = i
15481555
return result
15491556

15501557
fn argmin(self) -> Int:
15511558
var result: Int = 0
15521559
var min_val: SIMD[dtype, 1] = self.load[width=1](0)
15531560
for i in range(1, self.ndshape._size):
1554-
var temp: SIMD[dtype, 1] = self.load[width=1](i)
1555-
if temp < min_val:
1556-
min_val = temp
1561+
var temp: SIMD[dtype, 1] = self.load[width=1](i)
1562+
if temp < min_val:
1563+
min_val = temp
15571564
result = i
15581565
return result
15591566

@@ -1563,13 +1570,16 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
15631570
fn astype[type: DType](inout self) raises -> NDArray[type]:
15641571
# I wonder if we can do this operation inplace instead of allocating memory.
15651572
alias nelts = simdwidthof[dtype]()
1566-
var narr: NDArray[type] = NDArray[type](self.ndshape, random=False, order=self.order)
1573+
var narr: NDArray[type] = NDArray[type](
1574+
self.ndshape, random=False, order=self.order
1575+
)
15671576
narr.datatype = type
1577+
15681578
@parameter
15691579
fn vectorized_astype[width: Int](idx: Int) -> None:
15701580
narr.store[width](idx, self.load[width](idx).cast[type]())
15711581

1572-
vectorize[vectorized_astype, nelts](self.ndshape._size)
1582+
vectorize[vectorized_astype, nelts](self.ndshape._size)
15731583
return narr
15741584

15751585
# fn clip(self):
@@ -1619,8 +1629,11 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
16191629
# self.stride = NDArrayStride(shape = self.ndshape, offset=0)
16201630
# return self
16211631

1622-
var res: NDArray[dtype] = NDArray[dtype](self.ndshape._size, random=False)
1632+
var res: NDArray[dtype] = NDArray[dtype](
1633+
self.ndshape._size, random=False
1634+
)
16231635
alias simd_width: Int = simdwidthof[dtype]()
1636+
16241637
@parameter
16251638
fn vectorized_flatten[simd_width: Int](index: Int) -> None:
16261639
res.data.store[width=simd_width](

numojo/core/ndarray_utils.mojo

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ fn _traverse_iterative[
6868
if depth == ndim.__len__():
6969
var idx = offset + _get_index(index, coefficients)
7070
var nidx = _get_index(index, strides)
71-
var temp = orig.data.load[width=1](
72-
idx
73-
)
71+
var temp = orig.data.load[width=1](idx)
7472
if nidx >= narr.ndshape._size:
7573
raise Error("Invalid index: index out of bound")
7674
else:

numojo/core/sort.mojo

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@ TODO:
2020
# Bubble sort
2121
# ===------------------------------------------------------------------------===#
2222

23-
fn bubble_sort[
24-
dtype: DType
25-
](
26-
ndarray: NDArray[dtype]
27-
) raises -> NDArray[dtype]:
23+
24+
fn bubble_sort[dtype: DType](ndarray: NDArray[dtype]) raises -> NDArray[dtype]:
2825
"""
2926
Bubble sort the NDArray.
3027
Average complexity: O(n^2) comparisons, O(n^2) swaps.
@@ -44,11 +41,11 @@ fn bubble_sort[
4441
var length = ndarray.size()
4542

4643
for i in range(length):
47-
for j in range(length-i-1):
48-
if result.data.load[width=1](j) > result.data.load[width=1](j+1):
44+
for j in range(length - i - 1):
45+
if result.data.load[width=1](j) > result.data.load[width=1](j + 1):
4946
var temp = result.data.load[width=1](j)
50-
result.data.store[width=1](j, result.data.load[width=1](j+1))
51-
result.data.store[width=1](j+1, temp)
47+
result.data.store[width=1](j, result.data.load[width=1](j + 1))
48+
result.data.store[width=1](j + 1, temp)
5249

5350
return result
5451

@@ -57,13 +54,10 @@ fn bubble_sort[
5754
# Quick sort
5855
# ===------------------------------------------------------------------------===#
5956

57+
6058
fn _partition(
61-
inout ndarray: NDArray,
62-
left: Int,
63-
right: Int,
64-
pivot_index: Int
65-
) raises -> Int:
66-
59+
inout ndarray: NDArray, left: Int, right: Int, pivot_index: Int
60+
) raises -> Int:
6761
var pivot_value = ndarray[pivot_index]
6862
ndarray[pivot_index], ndarray[right] = ndarray[right], ndarray[pivot_index]
6963
var store_index = left
@@ -73,14 +67,13 @@ fn _partition(
7367
ndarray[store_index], ndarray[i] = ndarray[i], ndarray[store_index]
7468
store_index = store_index + 1
7569
ndarray[right], ndarray[store_index] = ndarray[store_index], ndarray[right]
76-
70+
7771
return store_index
7872

79-
fn quick_sort_inplace[dtype: DType](
80-
inout ndarray: NDArray[dtype],
81-
left: Int,
82-
right: Int,
83-
) raises:
73+
74+
fn quick_sort_inplace[
75+
dtype: DType
76+
](inout ndarray: NDArray[dtype], left: Int, right: Int,) raises:
8477
"""
8578
Quick sort (in-place) the NDArray.
8679
@@ -96,12 +89,13 @@ fn quick_sort_inplace[dtype: DType](
9689
if right > left:
9790
var pivot_index = left + (right - left) // 2
9891
var pivot_new_index = _partition(ndarray, left, right, pivot_index)
99-
quick_sort_inplace(ndarray, left, pivot_new_index-1)
100-
quick_sort_inplace(ndarray, pivot_new_index+1, right)
92+
quick_sort_inplace(ndarray, left, pivot_new_index - 1)
93+
quick_sort_inplace(ndarray, pivot_new_index + 1, right)
10194

102-
fn quick_sort[dtype: DType](
103-
ndarray: NDArray[dtype],
104-
) raises -> NDArray[dtype]:
95+
96+
fn quick_sort[
97+
dtype: DType
98+
](ndarray: NDArray[dtype],) raises -> NDArray[dtype]:
10599
"""
106100
Quick sort the NDArray.
107101
Adopt in-place partition.
@@ -123,14 +117,16 @@ fn quick_sort[dtype: DType](
123117
var result: NDArray[dtype] = ndarray
124118
var length = ndarray.size()
125119

126-
quick_sort_inplace(result, 0, length-1)
120+
quick_sort_inplace(result, 0, length - 1)
127121

128122
return result
129123

124+
130125
# ===------------------------------------------------------------------------===#
131126
# Binary sort
132127
# ===------------------------------------------------------------------------===#
133128

129+
134130
fn binary_sort[
135131
in_dtype: DType, out_dtype: DType = DType.float64
136132
](array: NDArray[in_dtype]) raises -> NDArray[out_dtype]:
@@ -158,4 +154,4 @@ fn binary_sort[
158154
var temp: Scalar[out_dtype] = result[i - 1]
159155
result[i - 1] = result[i]
160156
result[i] = temp
161-
return result
157+
return result

numojo/core/utility_funcs.mojo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ fn is_inttype(dtype: DType) -> Bool:
2727
return True
2828
return False
2929

30+
3031
fn is_floattype[dtype: DType]() -> Bool:
3132
if (
3233
dtype == DType.float16

numojo/math/__init__.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ from ._math_funcs import (
1010
VectorizedParallelized,
1111
VectorizedUnroll,
1212
VectorizedVerbose,
13-
VectorizedParallelizedNWorkers
13+
VectorizedParallelizedNWorkers,
1414
)

0 commit comments

Comments
 (0)