Skip to content

Commit 4d18575

Browse files
authored
Merge pull request #79 from shivasankarka/experimental
Array manipulation routines, booleans masks and Fixes
2 parents 413ca6d + 97b9d70 commit 4d18575

19 files changed

+997
-283
lines changed

.gitignore

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
/dist
44
/local
55
/.vscode
6-
7-
# Compiled package
8-
numojo.mojopkg
6+
.DS_Store
7+
*.html
8+
*.css
9+
*.py
10+
mojo
11+
numojo.mojopkg
12+
.gitignore
13+
bench.mojo
14+
test_ndarray.ipynb

numojo/core/array_creation_routines.mojo

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ fn arange[
3838
Error if both dtype and dtype are integers or if dtype is a float and dtype is an integer.
3939
4040
Parameters:
41-
dtype: Input datatype of the input values.
42-
dtype: Output datatype of the output NDArray.
41+
dtype: Datatype of the output array.
4342
4443
Args:
4544
start: Scalar[dtype] - Start value.
@@ -63,7 +62,9 @@ fn arange[
6362
NDArrayShape(num, size=num)
6463
)
6564
for idx in range(num):
66-
result.data[idx] = start.cast[dtype]() + step.cast[dtype]() * idx
65+
result.data[idx] = (
66+
start.cast[dtype]() + step.cast[dtype]() * idx
67+
)
6768

6869
return result
6970

@@ -87,11 +88,10 @@ fn linspace[
8788
Function that computes a series of linearly spaced values starting from "start" to "stop" with given size. Wrapper function for _linspace_serial, _linspace_parallel.
8889
8990
Raises:
90-
Error if both dtype and dtype are integers or if dtype is a float and dtype is an integer.
91+
Error if dtype is an integer.
9192
9293
Parameters:
93-
dtype: Datatype of the input values.
94-
dtype: Datatype of the output NDArray.
94+
dtype: Datatype of the output array.
9595
9696
Args:
9797
start: Start value.
@@ -220,11 +220,10 @@ fn logspace[
220220
Generate a logrithmic spaced NDArray of `num` elements between `start` and `stop`. Wrapper function for _logspace_serial, _logspace_parallel functions.
221221
222222
Raises:
223-
Error if both dtype and dtype are integers or if dtype is a float and dtype is an integer.
223+
Error if dtype is an integer.
224224
225225
Parameters:
226-
dtype: Datatype of the input values.
227-
dtype: Datatype of the output NDArray.
226+
dtype: Datatype of the output array.
228227
229228
Args:
230229
start: The starting value of the NDArray.
@@ -361,11 +360,10 @@ fn geomspace[
361360
Generate a NDArray of `num` elements between `start` and `stop` in a geometric series.
362361
363362
Raises:
364-
Error if both dtype and dtype are integers or if dtype is a float and dtype is an integer.
363+
Error if dtype is an integer.
365364
366365
Parameters:
367366
dtype: Datatype of the input values.
368-
dtype: Datatype of the output NDArray.
369367
370368
Args:
371369
start: The starting value of the NDArray.
@@ -543,8 +541,28 @@ fn full[
543541
return NDArray[dtype](shape, fill=tens_value)
544542

545543

546-
fn diagflat():
547-
pass
544+
fn diagflat[dtype: DType](inout v: NDArray[dtype], k: Int = 0) raises -> NDArray[dtype]:
545+
"""
546+
Generate a 2-D NDArray with the flattened input as the diagonal.
547+
548+
Parameters:
549+
dtype: Datatype of the NDArray elements.
550+
551+
Args:
552+
v: NDArray to be flattened and used as the diagonal.
553+
k: Diagonal offset.
554+
555+
Returns:
556+
A 2-D NDArray with the flattened input as the diagonal.
557+
"""
558+
v.reshape(v.ndshape.ndsize, 1)
559+
var n: Int= v.ndshape.ndsize + abs(k)
560+
var result: NDArray[dtype]= NDArray[dtype](n, n, random=False)
561+
562+
for i in range(n):
563+
print(n*i + i + k)
564+
result.store(n*i + i + k, v.data[i])
565+
return result
548566

549567

550568
fn tri():

numojo/core/array_manipulation_routines.mojo

Lines changed: 160 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Array manipulation routines.
33
"""
44
# ===----------------------------------------------------------------------=== #
55
# ARRAY MANIPULATION ROUTINES
6-
# Last updated: 2024-06-16
6+
# Last updated: 2024-08-03
77
# ===----------------------------------------------------------------------=== #
88

99

@@ -12,13 +12,166 @@ fn copyto():
1212
pass
1313

1414

15-
fn shape():
16-
pass
15+
fn ndim[dtype: DType](array: NDArray[dtype]) -> Int:
16+
"""
17+
Returns the number of dimensions of the NDArray.
1718
19+
Args:
20+
array: A NDArray.
1821
19-
fn reshape():
20-
pass
22+
Returns:
23+
The number of dimensions of the NDArray.
24+
"""
25+
return array.ndim
2126

2227

23-
fn ravel():
24-
pass
28+
fn shape[dtype: DType](array: NDArray[dtype]) -> NDArrayShape:
29+
"""
30+
Returns the shape of the NDArray.
31+
32+
Args:
33+
array: A NDArray.
34+
35+
Returns:
36+
The shape of the NDArray.
37+
"""
38+
return array.ndshape
39+
40+
41+
fn size[dtype: DType](array: NDArray[dtype], axis: Int) raises -> Int:
42+
"""
43+
Returns the size of the NDArray.
44+
45+
Args:
46+
array: A NDArray.
47+
axis: The axis to get the size of.
48+
49+
Returns:
50+
The size of the NDArray.
51+
"""
52+
return array.ndshape[axis]
53+
54+
55+
fn reshape[
56+
dtype: DType
57+
](
58+
inout array: NDArray[dtype], shape: VariadicList[Int], order: String = "C"
59+
) raises:
60+
"""
61+
Reshapes the NDArray to given Shape.
62+
63+
Raises:
64+
Error: If the number of elements do not match.
65+
66+
Args:
67+
array: A NDArray.
68+
shape: Variadic integers of shape.
69+
order: Order of the array - Row major `C` or Column major `F`.
70+
71+
"""
72+
var num_elements_new: Int = 1
73+
var ndim_new: Int = 0
74+
for i in shape:
75+
num_elements_new *= i
76+
ndim_new += 1
77+
78+
if array.ndshape.ndsize != num_elements_new:
79+
raise Error("Cannot reshape: Number of elements do not match.")
80+
81+
var shape_new: List[Int] = List[Int]()
82+
for i in range(ndim_new):
83+
shape_new.append(shape[i])
84+
var temp: Int = 1
85+
for j in range(i + 1, ndim_new): # temp
86+
temp *= shape[j]
87+
88+
array.ndim = ndim_new
89+
array.ndshape = NDArrayShape(shape=shape_new)
90+
array.stride = NDArrayStride(shape=shape_new, order=order)
91+
array.order = order
92+
93+
94+
fn ravel[dtype: DType](inout array: NDArray[dtype], order: String = "C") raises:
95+
"""
96+
Returns the raveled version of the NDArray.
97+
"""
98+
if array.ndim == 1:
99+
print("Array is already 1D")
100+
return
101+
else:
102+
if order == "C":
103+
reshape[dtype](array, array.ndshape.ndsize, order="C")
104+
else:
105+
reshape[dtype](array, array.ndshape.ndsize, order="F")
106+
107+
108+
fn where[
109+
dtype: DType
110+
](
111+
inout x: NDArray[dtype], scalar: SIMD[dtype, 1], mask: NDArray[DType.bool]
112+
) raises:
113+
"""
114+
Replaces elements in `x` with `scalar` where `mask` is True.
115+
116+
Parameters:
117+
dtype: DType.
118+
119+
Args:
120+
x: A NDArray.
121+
scalar: A SIMD value.
122+
mask: A NDArray.
123+
124+
"""
125+
for i in range(x.ndshape.ndsize):
126+
if mask.data[i] == True:
127+
x.data.store(i, scalar)
128+
129+
130+
# TODO: do it with vectorization
131+
fn where[
132+
dtype: DType
133+
](inout x: NDArray[dtype], y: NDArray[dtype], mask: NDArray[DType.bool]) raises:
134+
"""
135+
Replaces elements in `x` with elements from `y` where `mask` is True.
136+
137+
Raises:
138+
ShapeMismatchError: If the shapes of `x` and `y` do not match.
139+
140+
Parameters:
141+
dtype: DType.
142+
143+
Args:
144+
x: NDArray[dtype].
145+
y: NDArray[dtype].
146+
mask: NDArray[DType.bool].
147+
148+
"""
149+
if x.ndshape != y.ndshape:
150+
raise Error("Shape mismatch error: x and y must have the same shape")
151+
for i in range(x.ndshape.ndsize):
152+
if mask.data[i] == True:
153+
x.data.store(i, y.data[i])
154+
155+
156+
fn flip[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
157+
"""
158+
Flips the NDArray along the given axis.
159+
160+
Parameters:
161+
dtype: DType.
162+
163+
Args:
164+
array: A NDArray.
165+
166+
Returns:
167+
The flipped NDArray.
168+
"""
169+
if array.ndim != 1:
170+
raise Error("Flip is only supported for 1D arrays")
171+
172+
var result: NDArray[dtype] = NDArray[dtype](
173+
shape=array.ndshape, order=array.order
174+
)
175+
for i in range(array.ndshape.ndsize):
176+
result.data.store(i, array.data[array.ndshape.ndsize - i - 1])
177+
return result

0 commit comments

Comments
 (0)