Skip to content

Commit 8186939

Browse files
authored
Merge pull request #53 from MadAlex1997/mad_ndarray_compare
Add inout to flatten using Optional, fix comparison functions and min…
2 parents 4f940d7 + d4bd78b commit 8186939

File tree

5 files changed

+189
-35
lines changed

5 files changed

+189
-35
lines changed

numojo/core/array_creation_routines.mojo

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ fn arange[
4444
Returns:
4545
NDArray[dtype] - NDArray of datatype T with elements ranging from "start" to "stop" incremented with "step".
4646
"""
47-
if is_inttype[in_dtype]() and is_inttype[out_dtype]():
47+
if is_floattype[in_dtype]() and is_inttype[out_dtype]():
4848
raise Error(
49-
"Input and output cannot be `Int` datatype as it may lead to"
50-
" precision errors"
49+
"""
50+
If in_dtype is a float then out_dtype must also be a float
51+
"""
5152
)
5253

5354
var num: Int = ((stop - start) / step).__int__()

numojo/core/ndarray.mojo

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ from ..math.statistics.cumulative_reduce import (
3434
)
3535
from ..math.check import any, all
3636
from ..math.arithmetic import abs
37-
from .ndarray_utils import _get_index, _traverse_iterative, to_numpy
37+
from .ndarray_utils import _get_index, _traverse_iterative, to_numpy, bool_to_numeric
3838
from .utility_funcs import is_inttype
3939
from ..math.linalg.matmul import matmul_parallelized
4040

@@ -1612,7 +1612,7 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
16121612
vectorize[vectorized_fill, simd_width](self.ndshape._size)
16131613
return self
16141614

1615-
fn flatten(inout self, inplace: Bool = False) raises -> Self:
1615+
fn flatten(inout self, inplace: Bool = False) raises -> Optional[Self]:
16161616
# inplace has some problems right now
16171617
# if inplace:
16181618
# self.ndshape = NDArrayShape(self.ndshape._size, size=self.ndshape._size)
@@ -1628,15 +1628,18 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
16281628
)
16291629

16301630
vectorize[vectorized_flatten, simd_width](self.ndshape._size)
1631-
return res
1631+
if inplace:
1632+
self = res
1633+
return None
1634+
else:
1635+
return res
16321636

16331637
fn item(self, *indices: Int) raises -> SIMD[dtype, 1]: # I should add
16341638
if indices.__len__() == 1:
16351639
return self.data.load[width=1](indices[0])
16361640
else:
16371641
return self.data.load[width=1](_get_index(indices, self.stride))
16381642

1639-
#TODO: not finished yet
16401643
fn max(self, axis: Int = 0) raises -> Self:
16411644
var ndim: Int = self.ndim
16421645
var shape: List[Int] = List[Int]()
@@ -1655,16 +1658,49 @@ struct NDArray[dtype: DType = DType.float32](Stringable, CollectionElement, Size
16551658
slices.append(Slice(0, 0))
16561659
print(result_shape.__str__())
16571660
var result: NDArray[dtype] = NDArray[dtype](NDArrayShape(result_shape))
1658-
1659-
for i in range(axis_size):
1661+
slices[axis] = Slice(0, 1)
1662+
result = self[slices]
1663+
for i in range(1,axis_size):
16601664
slices[axis] = Slice(i, i + 1)
16611665
var arr_slice = self[slices]
1662-
result += maxT(arr_slice)
1666+
var mask1 = greater(arr_slice,result)
1667+
var mask2 = less(arr_slice,result)
1668+
# Wherever result is less than the new slice it is set to zero
1669+
# Wherever arr_slice is greater than the old result it is added to fill those zeros
1670+
result = add(result * bool_to_numeric[dtype](mask2),arr_slice * bool_to_numeric[dtype](mask1))
16631671

16641672
return result
16651673

1666-
fn min(self, axis: Int = 0):
1667-
pass
1674+
fn min(self, axis: Int = 0)raises-> Self:
1675+
var ndim: Int = self.ndim
1676+
var shape: List[Int] = List[Int]()
1677+
for i in range(ndim):
1678+
shape.append(self.ndshape[i])
1679+
if axis > ndim - 1:
1680+
raise Error("axis cannot be greater than the rank of the array")
1681+
var result_shape: List[Int] = List[Int]()
1682+
var axis_size: Int = shape[axis]
1683+
var slices: List[Slice] = List[Slice]()
1684+
for i in range(ndim):
1685+
if i != axis:
1686+
result_shape.append(shape[i])
1687+
slices.append(Slice(0, shape[i]))
1688+
else:
1689+
slices.append(Slice(0, 0))
1690+
1691+
var result: NDArray[dtype] = NDArray[dtype](NDArrayShape(result_shape))
1692+
slices[axis] = Slice(0, 1)
1693+
result = self[slices]
1694+
for i in range(1,axis_size):
1695+
slices[axis] = Slice(i, i + 1)
1696+
var arr_slice = self[slices]
1697+
var mask1 = less(arr_slice,result)
1698+
var mask2 = greater(arr_slice,result)
1699+
# Wherever result is greater than the new slice it is set to zero
1700+
# Wherever arr_slice is less than the old result it is added to fill those zeros
1701+
result = add(result * bool_to_numeric[dtype](mask2),arr_slice * bool_to_numeric[dtype](mask1))
1702+
1703+
return result
16681704

16691705
fn mean(self: Self, axis: Int) raises -> Self:
16701706
"""

numojo/core/ndarray_utils.mojo

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ fn _traverse_iterative[
8484
orig, narr, ndim, coefficients, strides, offset, index, newdepth
8585
)
8686

87+
fn bool_to_numeric[dtype:DType](array: NDArray[DType.bool])raises->NDArray[dtype]:
88+
# Can't use simd becuase of bit packing error
89+
var res: NDArray[dtype] = NDArray[dtype](array.shape())
90+
for i in range(array.size()):
91+
var t = array[i]
92+
if t:
93+
res[i] = 1
94+
else:
95+
res[i] = 0
96+
return res
8797

8898
fn to_numpy[dtype: DType](array: NDArray[dtype]) raises -> PythonObject:
8999
try:

numojo/math/comparison.mojo

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,17 @@ fn greater[
3232
3333
An element of the result NDArray will be True if the corresponding element in x is greater than the corresponding element in y, and False otherwise.
3434
"""
35-
return backend()._math_func_compare_2_tensors[dtype, SIMD.__gt__](
36-
array1, array2
37-
)
35+
# return backend()._math_func_compare_2_tensors[dtype, SIMD.__gt__](
36+
# array1, array2
37+
# )
38+
if array1.shape() != array2.shape():
39+
raise Error(
40+
"Shape Mismatch error shapes must match for this function"
41+
)
42+
var result_array: NDArray[DType.bool] = NDArray[DType.bool](array1.shape())
43+
for i in range(result_array.size()):
44+
result_array[i] = array1[i]>array2[i]
45+
return result_array
3846

3947

4048
fn greater_equal[
@@ -56,10 +64,17 @@ fn greater_equal[
5664
5765
An element of the result NDArray will be True if the corresponding element in x is greater than or equal to the corresponding element in y, and False otherwise.
5866
"""
59-
return backend()._math_func_compare_2_tensors[dtype, SIMD.__ge__](
60-
array1, array2
61-
)
62-
67+
# return backend()._math_func_compare_2_tensors[dtype, SIMD.__ge__](
68+
# array1, array2
69+
# )
70+
if array1.shape() != array2.shape():
71+
raise Error(
72+
"Shape Mismatch error shapes must match for this function"
73+
)
74+
var result_array: NDArray[DType.bool] = NDArray[DType.bool](array1.shape())
75+
for i in range(result_array.size()):
76+
result_array[i] = array1[i]>=array2[i]
77+
return result_array
6378

6479
fn less[
6580
dtype: DType, backend: _mf.Backend = _mf.Vectorized
@@ -80,10 +95,17 @@ fn less[
8095
8196
An element of the result NDArray will be True if the corresponding element in x is or equal to the corresponding element in y, and False otherwise.
8297
"""
83-
return backend()._math_func_compare_2_tensors[dtype, SIMD.__lt__](
84-
array1, array2
85-
)
86-
98+
# return backend()._math_func_compare_2_tensors[dtype, SIMD.__lt__](
99+
# array1, array2
100+
# )
101+
if array1.shape() != array2.shape():
102+
raise Error(
103+
"Shape Mismatch error shapes must match for this function"
104+
)
105+
var result_array: NDArray[DType.bool] = NDArray[DType.bool](array1.shape())
106+
for i in range(result_array.size()):
107+
result_array[i] = array1[i]<array2[i]
108+
return result_array
87109

88110
fn less_equal[
89111
dtype: DType, backend: _mf.Backend = _mf.Vectorized
@@ -104,9 +126,17 @@ fn less_equal[
104126
105127
An element of the result NDArray will be True if the corresponding element in x is less than or equal to the corresponding element in y, and False otherwise.
106128
"""
107-
return backend()._math_func_compare_2_tensors[dtype, SIMD.__le__](
108-
array1, array2
109-
)
129+
# return backend()._math_func_compare_2_tensors[dtype, SIMD.__le__](
130+
# array1, array2
131+
# )
132+
if array1.shape() != array2.shape():
133+
raise Error(
134+
"Shape Mismatch error shapes must match for this function"
135+
)
136+
var result_array: NDArray[DType.bool] = NDArray[DType.bool](array1.shape())
137+
for i in range(result_array.size()):
138+
result_array[i] = array1[i]<=array2[i]
139+
return result_array
110140

111141

112142
fn equal[
@@ -128,10 +158,17 @@ fn equal[
128158
129159
An element of the result NDArray will be True if the corresponding element in x is equal to the corresponding element in y, and False otherwise.
130160
"""
131-
return backend()._math_func_compare_2_tensors[dtype, SIMD.__eq__](
132-
array1, array2
133-
)
134-
161+
# return backend()._math_func_compare_2_tensors[dtype, SIMD.__eq__](
162+
# array1, array2
163+
# )
164+
if array1.shape() != array2.shape():
165+
raise Error(
166+
"Shape Mismatch error shapes must match for this function"
167+
)
168+
var result_array: NDArray[DType.bool] = NDArray[DType.bool](array1.shape())
169+
for i in range(result_array.size()):
170+
result_array[i] = array1[i]==array2[i]
171+
return result_array
135172

136173
fn not_equal[
137174
dtype: DType, backend: _mf.Backend = _mf.Vectorized
@@ -152,6 +189,14 @@ fn not_equal[
152189
153190
An element of the result NDArray will be True if the corresponding element in x is not equal to the corresponding element in y, and False otherwise.
154191
"""
155-
return backend()._math_func_compare_2_tensors[dtype, SIMD.__ne__](
156-
array1, array2
157-
)
192+
# return backend()._math_func_compare_2_tensors[dtype, SIMD.__ne__](
193+
# array1, array2
194+
# )
195+
if array1.shape() != array2.shape():
196+
raise Error(
197+
"Shape Mismatch error shapes must match for this function"
198+
)
199+
var result_array: NDArray[DType.bool] = NDArray[DType.bool](array1.shape())
200+
for i in range(result_array.size()):
201+
result_array[i] = array1[i]!=array2[i]
202+
return result_array

numojo/math/statistics/stats.mojo

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77
# from numojo.core.NDArray import NDArray
88
from ...core.ndarray import NDArray
9-
9+
from .. import mul
1010

1111
fn sum(array: NDArray, axis: Int = 0) raises -> NDArray[array.dtype]:
1212
"""
@@ -96,7 +96,7 @@ fn prod(array: NDArray, axis: Int = 0) raises -> NDArray[array.dtype]:
9696
for i in range(1, axis_size):
9797
slices[axis] = Slice(i, i + 1)
9898
var arr_slice = array[slices]
99-
result = result * arr_slice
99+
result = mul[array.dtype](result, arr_slice)
100100

101101
return result
102102

@@ -137,3 +137,65 @@ fn meanall(array: NDArray) raises -> Float64:
137137
"""
138138

139139
return sumall(array).cast[DType.float64]() / Int32(array.ndshape._size).cast[DType.float64]()
140+
141+
fn max[dtype:DType](array:NDArray[dtype], axis: Int = 0) raises -> NDArray[dtype]:
142+
var ndim: Int = array.ndim
143+
var shape: List[Int] = List[Int]()
144+
for i in range(ndim):
145+
shape.append(array.ndshape[i])
146+
if axis > ndim - 1:
147+
raise Error("axis cannot be greater than the rank of the array")
148+
var result_shape: List[Int] = List[Int]()
149+
var axis_size: Int = shape[axis]
150+
var slices: List[Slice] = List[Slice]()
151+
for i in range(ndim):
152+
if i != axis:
153+
result_shape.append(shape[i])
154+
slices.append(Slice(0, shape[i]))
155+
else:
156+
slices.append(Slice(0, 0))
157+
print(result_shape.__str__())
158+
var result: NDArray[dtype] = NDArray[dtype](NDArrayShape(result_shape))
159+
slices[axis] = Slice(0, 1)
160+
result = array[slices]
161+
for i in range(1,axis_size):
162+
slices[axis] = Slice(i, i + 1)
163+
var arr_slice = array[slices]
164+
var mask1 = greater(arr_slice,result)
165+
var mask2 = less(arr_slice,result)
166+
# Wherever result is less than the new slice it is set to zero
167+
# Wherever arr_slice is greater than the old result it is added to fill those zeros
168+
result = add(result * bool_to_numeric[dtype](mask2),arr_slice * bool_to_numeric[dtype](mask1))
169+
170+
return result
171+
172+
fn min[dtype:DType](array:NDArray[dtype], axis: Int = 0)raises-> NDArray[dtype]:
173+
var ndim: Int = array.ndim
174+
var shape: List[Int] = List[Int]()
175+
for i in range(ndim):
176+
shape.append(array.ndshape[i])
177+
if axis > ndim - 1:
178+
raise Error("axis cannot be greater than the rank of the array")
179+
var result_shape: List[Int] = List[Int]()
180+
var axis_size: Int = shape[axis]
181+
var slices: List[Slice] = List[Slice]()
182+
for i in range(ndim):
183+
if i != axis:
184+
result_shape.append(shape[i])
185+
slices.append(Slice(0, shape[i]))
186+
else:
187+
slices.append(Slice(0, 0))
188+
189+
var result: NDArray[dtype] = NDArray[dtype](NDArrayShape(result_shape))
190+
slices[axis] = Slice(0, 1)
191+
result = array[slices]
192+
for i in range(1,axis_size):
193+
slices[axis] = Slice(i, i + 1)
194+
var arr_slice = array[slices]
195+
var mask1 = less(arr_slice,result)
196+
var mask2 = greater(arr_slice,result)
197+
# Wherever result is greater than the new slice it is set to zero
198+
# Wherever arr_slice is less than the old result it is added to fill those zeros
199+
result = add(result * bool_to_numeric[dtype](mask2),arr_slice * bool_to_numeric[dtype](mask1))
200+
201+
return result

0 commit comments

Comments
 (0)