Skip to content

Commit 377b802

Browse files
committed
formatted and fixed typos
1 parent 09b890e commit 377b802

17 files changed

+305
-128
lines changed

numojo/core/array_creation_routines.mojo

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,18 @@ fn arange[
5151
is_inttype[in_dtype]() and is_inttype[out_dtype]()
5252
):
5353
raise Error(
54-
"Both input and output datatypes cannot be integers. If the input is a float, the output must also be a float."
54+
"Both input and output datatypes cannot be integers. If the input"
55+
" is a float, the output must also be a float."
5556
)
5657

5758
var num: Int = ((stop - start) / step).__int__()
5859
var result: NDArray[out_dtype] = NDArray[out_dtype](
5960
NDArrayShape(num, size=num)
6061
)
6162
for idx in range(num):
62-
result.data[idx] = start.cast[out_dtype]() + step.cast[out_dtype]() * idx
63+
result.data[idx] = (
64+
start.cast[out_dtype]() + step.cast[out_dtype]() * idx
65+
)
6366

6467
return result
6568

@@ -104,7 +107,8 @@ fn linspace[
104107
is_floattype[in_dtype]() and is_inttype[out_dtype]()
105108
):
106109
raise Error(
107-
"Both input and output datatypes cannot be integers. If the input is a float, the output must also be a float."
110+
"Both input and output datatypes cannot be integers. If the input"
111+
" is a float, the output must also be a float."
108112
)
109113

110114
if parallel:
@@ -237,7 +241,8 @@ fn logspace[
237241
is_floattype[in_dtype]() and is_inttype[out_dtype]()
238242
):
239243
raise Error(
240-
"Both input and output datatypes cannot be integers. If the input is a float, the output must also be a float."
244+
"Both input and output datatypes cannot be integers. If the input"
245+
" is a float, the output must also be a float."
241246
)
242247
if parallel:
243248
return _logspace_parallel[out_dtype](
@@ -376,7 +381,8 @@ fn geomspace[
376381
is_floattype[in_dtype]() and is_inttype[out_dtype]()
377382
):
378383
raise Error(
379-
"Both input and output datatypes cannot be integers. If the input is a float, the output must also be a float."
384+
"Both input and output datatypes cannot be integers. If the input"
385+
" is a float, the output must also be a float."
380386
)
381387

382388
var a: Scalar[out_dtype] = start.cast[out_dtype]()

numojo/core/array_manipulation_routines.mojo

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""
22
# ===----------------------------------------------------------------------=== #
33
# ARRAY MANIPULATION ROUTINES
4-
# Last updated: 2024-07-21
4+
# Last updated: 2024-08-03
55
# ===----------------------------------------------------------------------=== #
66
"""
77

88

99
fn copyto():
1010
pass
1111

12+
1213
fn ndim[dtype: DType](array: NDArray[dtype]) -> Int:
1314
"""
1415
Returns the number of dimensions of the NDArray.
@@ -34,6 +35,7 @@ fn shape[dtype: DType](array: NDArray[dtype]) -> NDArrayShape:
3435
"""
3536
return array.ndshape
3637

38+
3739
fn size[dtype: DType](array: NDArray[dtype], axis: Int) raises -> Int:
3840
"""
3941
Returns the size of the NDArray.
@@ -47,7 +49,12 @@ fn size[dtype: DType](array: NDArray[dtype], axis: Int) raises -> Int:
4749
"""
4850
return array.ndshape[axis]
4951

50-
fn reshape[dtype: DType](inout array: NDArray[dtype], shape: VariadicList[Int], order: String = "C") raises:
52+
53+
fn reshape[
54+
dtype: DType
55+
](
56+
inout array: NDArray[dtype], shape: VariadicList[Int], order: String = "C"
57+
) raises:
5158
"""
5259
Reshapes the NDArray to given Shape.
5360
@@ -58,7 +65,7 @@ fn reshape[dtype: DType](inout array: NDArray[dtype], shape: VariadicList[Int],
5865
array: A NDArray.
5966
shape: Variadic integers of shape.
6067
order: Order of the array - Row major `C` or Column major `F`.
61-
68+
6269
"""
6370
var num_elements_new: Int = 1
6471
var ndim_new: Int = 0
@@ -81,26 +88,32 @@ fn reshape[dtype: DType](inout array: NDArray[dtype], shape: VariadicList[Int],
8188
array.stride = NDArrayStride(shape=shape_new, order=order)
8289
array.order = order
8390

91+
8492
fn ravel[dtype: DType](inout array: NDArray[dtype], order: String = "C") raises:
8593
"""
8694
Returns the raveled version of the NDArray.
8795
"""
8896
if array.ndim == 1:
8997
print("Array is already 1D")
90-
return
98+
return
9199
else:
92100
if order == "C":
93101
reshape[dtype](array, array.ndshape._size, order="C")
94102
else:
95103
reshape[dtype](array, array.ndshape._size, order="F")
96104

97-
fn where[dtype: DType](inout x: NDArray[dtype], scalar: SIMD[dtype, 1], mask:NDArray[DType.bool]) raises:
105+
106+
fn where[
107+
dtype: DType
108+
](
109+
inout x: NDArray[dtype], scalar: SIMD[dtype, 1], mask: NDArray[DType.bool]
110+
) raises:
98111
"""
99112
Replaces elements in `x` with `scalar` where `mask` is True.
100113
101114
Parameters:
102115
dtype: DType.
103-
116+
104117
Args:
105118
x: A NDArray.
106119
scalar: A SIMD value.
@@ -111,8 +124,11 @@ fn where[dtype: DType](inout x: NDArray[dtype], scalar: SIMD[dtype, 1], mask:NDA
111124
if mask.data[i] == True:
112125
x.data.store(i, scalar)
113126

127+
114128
# TODO: do it with vectorization
115-
fn where[dtype: DType](inout x: NDArray[dtype], y: NDArray[dtype], mask:NDArray[DType.bool]) raises:
129+
fn where[
130+
dtype: DType
131+
](inout x: NDArray[dtype], y: NDArray[dtype], mask: NDArray[DType.bool]) raises:
116132
"""
117133
Replaces elements in `x` with elements from `y` where `mask` is True.
118134
@@ -121,7 +137,7 @@ fn where[dtype: DType](inout x: NDArray[dtype], y: NDArray[dtype], mask:NDArray[
121137
122138
Parameters:
123139
dtype: DType.
124-
140+
125141
Args:
126142
x: NDArray[dtype].
127143
y: NDArray[dtype].
@@ -135,13 +151,13 @@ fn where[dtype: DType](inout x: NDArray[dtype], y: NDArray[dtype], mask:NDArray[
135151
x.data.store(i, y.data[i])
136152

137153

138-
fn flip[dtype: DType](inout array: NDArray[dtype]) raises -> NDArray[dtype]:
154+
fn flip[dtype: DType](array: NDArray[dtype]) raises -> NDArray[dtype]:
139155
"""
140156
Flips the NDArray along the given axis.
141157
142158
Parameters:
143159
dtype: DType.
144-
160+
145161
Args:
146162
array: A NDArray.
147163
@@ -151,8 +167,9 @@ fn flip[dtype: DType](inout array: NDArray[dtype]) raises -> NDArray[dtype]:
151167
if array.ndim != 1:
152168
raise Error("Flip is only supported for 1D arrays")
153169

154-
var result: NDArray[dtype] = NDArray[dtype](shape=array.ndshape, order=array.order)
170+
var result: NDArray[dtype] = NDArray[dtype](
171+
shape=array.ndshape, order=array.order
172+
)
155173
for i in range(array.ndshape._size):
156174
result.data.store(i, array.data[array.ndshape._size - i - 1])
157-
158175
return result

numojo/core/ndarray.mojo

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ from .ndarray_utils import (
4949
from ..math.math_funcs import Vectorized
5050
from .utility_funcs import is_inttype
5151
from ..math.linalg.matmul import matmul_parallelized
52-
from .array_manipulation_routines import reshape
52+
from .array_manipulation_routines import reshape
5353

5454

5555
@register_passable("trivial")
@@ -968,7 +968,7 @@ struct NDArray[dtype: DType = DType.float32](
968968
if self.ndim == 1:
969969
narr.ndim = 0
970970
narr.ndshape._shape[0] = 0
971-
971+
972972
return narr
973973

974974
fn _adjust_slice_(self, inout span: Slice, dim: Int):
@@ -2435,7 +2435,6 @@ struct NDArray[dtype: DType = DType.float32](
24352435
var s: VariadicList[Int] = shape
24362436
reshape[dtype](self, s, order=order)
24372437

2438-
24392438
fn unsafe_ptr(self) -> DTypePointer[dtype, 0]:
24402439
return self.data
24412440

numojo/core/ndarray_utils.mojo

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from .ndarray import NDArray, NDArrayShape, NDArrayStride
1010

1111
# TODO: there's some problem with using narr[idx] in traverse function, Make sure to correct this before v0.1
1212

13+
1314
fn _get_index(indices: List[Int], weights: NDArrayShape) raises -> Int:
1415
"""
1516
Get the index of a multi-dimensional array from a list of indices and weights.
@@ -192,10 +193,10 @@ fn to_numpy[dtype: DType](array: NDArray[dtype]) raises -> PythonObject:
192193
Convert a NDArray to a numpy array.
193194
194195
Example:
195-
```console
196+
```console
196197
var arr = NDArray[DType.float32](3, 3, 3)
197-
var np_arr = to_numpy(arr)
198-
var np_arr1 = arr.to_numpy()
198+
var np_arr = to_numpy(arr)
199+
var np_arr1 = arr.to_numpy()
199200
```
200201
201202
Parameters:

numojo/core/sort.mojo

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,11 @@ fn _argsort_partition(
224224
"""
225225

226226
var pivot_value = ndarray.get_scalar(pivot_index)
227-
227+
228228
var _value_at_pivot = ndarray.get_scalar(pivot_index)
229229
ndarray.__setitem__(pivot_index, ndarray.get_scalar(right))
230230
ndarray.__setitem__(right, _value_at_pivot)
231-
231+
232232
var _value_at_pivot_index = idx_array.get_scalar(pivot_index)
233233
idx_array.__setitem__(pivot_index, idx_array.get_scalar(right))
234234
idx_array.__setitem__(right, _value_at_pivot_index)
@@ -240,7 +240,7 @@ fn _argsort_partition(
240240
var _value_at_store = ndarray.get_scalar(store_index)
241241
ndarray.__setitem__(store_index, ndarray.get_scalar(i))
242242
ndarray.__setitem__(i, _value_at_store)
243-
243+
244244
var _value_at_store_index = idx_array.get_scalar(store_index)
245245
idx_array.__setitem__(store_index, idx_array.get_scalar(i))
246246
idx_array.__setitem__(i, _value_at_store_index)

numojo/core/utility_funcs.mojo

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ fn is_floattype[dtype: DType]() -> Bool:
6464
return True
6565
return False
6666

67+
6768
fn is_floattype(dtype: DType) -> Bool:
6869
"""
6970
Check if the given dtype is a floating point type at run time.
@@ -80,4 +81,4 @@ fn is_floattype(dtype: DType) -> Bool:
8081
or dtype == DType.float64
8182
):
8283
return True
83-
return False
84+
return False

numojo/math/__init__.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ from .math_funcs import (
1212
VectorizedVerbose,
1313
VectorizedParallelizedNWorkers,
1414
)
15-
from .interpolate import *
15+
from .interpolate import *

0 commit comments

Comments
 (0)