Skip to content

Commit b9d15c9

Browse files
Merge pull request #85 from MadAlex1997/remove-out-dtype
Removing out_dtype, and minor doc fixes
2 parents 4d18575 + 3115b45 commit b9d15c9

File tree

10 files changed

+190
-233
lines changed

10 files changed

+190
-233
lines changed

new_tests/test_sort.mojo

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import numojo as nm
2+
from time import now
3+
from python import Python, PythonObject
4+
from utils_for_test import check, check_is_close
5+
6+
def test_sort_1d():
7+
arr = nm.NDArray(25,random=True)
8+
var np = Python.import_module("numpy")
9+
arr_sorted = arr.sort()
10+
np_arr_sorted = np.sort(arr.to_numpy())
11+
return check(arr_sorted,np_arr_sorted, "quick sort is broken")
12+
13+
# ND sorting currently works differently than numpy which has an on axis
14+
15+
# def test_sort_2d():
16+
# arr = nm.NDArray(5,5,random=True)
17+
# var np = Python.import_module("numpy")
18+
# arr_sorted = arr.sort()
19+
# print(arr_sorted)
20+
# np_arr_sorted = np.sort(arr.to_numpy())
21+
# print(np_arr_sorted)
22+
# return check(arr_sorted,np_arr_sorted, "quick sort is broken")
23+
24+
# def main():
25+
# test_sort_1d()
26+
# # test_sort_2d()

numojo/core/array_creation_routines.mojo

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ fn arange[
6363
)
6464
for idx in range(num):
6565
result.data[idx] = (
66-
start.cast[dtype]() + step.cast[dtype]() * idx
66+
start + step * idx
6767
)
6868

6969
return result
@@ -113,11 +113,11 @@ fn linspace[
113113
constrained[not dtype.is_integral()]()
114114
if parallel:
115115
return _linspace_parallel[dtype](
116-
start.cast[dtype](), stop.cast[dtype](), num, endpoint
116+
start, stop, num, endpoint
117117
)
118118
else:
119119
return _linspace_serial[dtype](
120-
start.cast[dtype](), stop.cast[dtype](), num, endpoint
120+
start, stop, num, endpoint
121121
)
122122

123123

@@ -245,18 +245,18 @@ fn logspace[
245245
# )
246246
if parallel:
247247
return _logspace_parallel[dtype](
248-
start.cast[dtype](),
249-
stop.cast[dtype](),
248+
start,
249+
stop,
250250
num,
251-
base.cast[dtype](),
251+
base,
252252
endpoint,
253253
)
254254
else:
255255
return _logspace_serial[dtype](
256-
start.cast[dtype](),
257-
stop.cast[dtype](),
256+
start,
257+
stop,
258258
num,
259-
base.cast[dtype](),
259+
base,
260260
endpoint,
261261
)
262262

@@ -382,22 +382,22 @@ fn geomspace[
382382
# "Both input and output datatypes cannot be integers. If the input is a float, the output must also be a float."
383383
# )
384384

385-
var a: Scalar[dtype] = start.cast[dtype]()
385+
var a: Scalar[dtype] = start
386386

387387
if endpoint:
388388
var result: NDArray[dtype] = NDArray[dtype](NDArrayShape(num))
389389
var r: Scalar[dtype] = (
390-
stop.cast[dtype]() / start.cast[dtype]()
391-
) ** (1 / (num - 1)).cast[dtype]()
390+
stop / start
391+
) ** (1 / (num - 1))
392392
for i in range(num):
393393
result.data[i] = a * r**i
394394
return result
395395

396396
else:
397397
var result: NDArray[dtype] = NDArray[dtype](NDArrayShape(num))
398398
var r: Scalar[dtype] = (
399-
stop.cast[dtype]() / start.cast[dtype]()
400-
) ** (1 / (num)).cast[dtype]()
399+
stop / start
400+
) ** (1 / (num))
401401
for i in range(num):
402402
result.data[i] = a * r**i
403403
return result
@@ -537,7 +537,7 @@ fn full[
537537
Returns:
538538
A NDArray of `dtype` with given `shape`.
539539
"""
540-
var tens_value: SIMD[dtype, 1] = SIMD[dtype, 1](fill_value).cast[dtype]()
540+
var tens_value: SIMD[dtype, 1] = SIMD[dtype, 1](fill_value)
541541
return NDArray[dtype](shape, fill=tens_value)
542542

543543

numojo/core/datatypes.mojo

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -31,46 +31,3 @@ alias f32 = DType.float32
3131
"""Data type alias for DType.float32"""
3232
alias f64 = DType.float64
3333
"""Data type alias for DType.float64"""
34-
35-
36-
fn cvtdtype[
37-
in_dtype: DType, out_dtype: DType, width: Int = 1
38-
](value: SIMD[in_dtype, width]) -> SIMD[out_dtype, width]:
39-
"""
40-
Converts datatype of a value from in_dtype to out_dtype at run time.
41-
42-
Parameters:
43-
in_dtype: The input datatype.
44-
out_dtype: The output dataytpe.
45-
width: The width of the SIMD vector.
46-
47-
Args:
48-
value: The SIMD value to be converted.
49-
50-
Returns:
51-
The `value` with its dtype cast as out_dtype.
52-
53-
"""
54-
return value.cast[out_dtype]()
55-
56-
57-
fn cvtdtype[
58-
in_dtype: DType,
59-
out_dtype: DType,
60-
width: Int = 1,
61-
value: SIMD[in_dtype, width] = SIMD[in_dtype](),
62-
]() -> SIMD[out_dtype, width]:
63-
"""
64-
Converts datatype of a value from in_dtype to out_dtype at compile time.
65-
66-
Parameters:
67-
in_dtype: The input datatype.
68-
out_dtype: The output dataytpe.
69-
width: The width of the SIMD vector.
70-
value: The SIMD value to be converted.
71-
72-
Returns:
73-
The `value` with its dtype cast as out_dtype.
74-
75-
"""
76-
return value.cast[out_dtype]()

numojo/core/sort.mojo

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ fn quick_sort[
164164

165165

166166
fn binary_sort[
167-
in_dtype: DType, out_dtype: DType = DType.float64
168-
](array: NDArray[in_dtype]) raises -> NDArray[out_dtype]:
167+
dtype: DType = DType.float64
168+
](array: NDArray[ dtype]) raises -> NDArray[dtype]:
169169
"""
170170
Binary sorting of NDArray.
171171
@@ -177,24 +177,27 @@ fn binary_sort[
177177
```
178178
179179
Parameters:
180-
in_dtype: The input element type.
181-
out_dtype: The output element type.
180+
dtype: The element type.
182181
183182
Args:
184183
array: A NDArray.
185184
186185
Returns:
187-
The sorted NDArray of type `out_dtype`.
186+
The sorted NDArray of type `dtype`.
188187
"""
189-
var result: NDArray[out_dtype] = NDArray[out_dtype](array.shape())
188+
@parameter
189+
if dtype != array.dtype:
190+
alias dtype = array.dtype
191+
192+
var result: NDArray[dtype] = NDArray[dtype](array.shape())
190193
for i in range(array.ndshape.ndsize):
191-
result.store(i, array.get_scalar(i).cast[out_dtype]())
194+
result.store(i, array.get_scalar(i).cast[dtype]())
192195

193196
var n = array.num_elements()
194197
for end in range(n, 1, -1):
195198
for i in range(1, end):
196199
if result[i - 1] > result[i]:
197-
var temp: Scalar[out_dtype] = result.get_scalar(i - 1)
200+
var temp: Scalar[dtype] = result.get_scalar(i - 1)
198201
result[i - 1] = result[i]
199202
result.store(i, temp)
200203
return result

numojo/math/arithmetic.mojo

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,13 @@ fn sub[
198198

199199

200200
fn diff[
201-
in_dtype: DType, out_dtype: DType = in_dtype
202-
](array: NDArray[in_dtype], n: Int) raises -> NDArray[out_dtype]:
201+
dtype: DType = DType.float64
202+
](array: NDArray[ dtype], n: Int) raises -> NDArray[dtype]:
203203
"""
204204
Compute the n-th order difference of the input array.
205205
206206
Parameters:
207-
in_dtype: Input data type.
208-
out_dtype: Output data type, defaults to float32.
207+
dtype: The element type.
209208
210209
Args:
211210
array: A array.
@@ -215,19 +214,19 @@ fn diff[
215214
The n-th order difference of the input array.
216215
"""
217216

218-
var array1: NDArray[out_dtype] = NDArray[out_dtype](
217+
var array1: NDArray[dtype] = NDArray[dtype](
219218
NDArrayShape(array.num_elements())
220219
)
221220
for i in range(array.num_elements()):
222-
array1.store(i, array.get_scalar(i).cast[out_dtype]())
221+
array1.store(i, array.get_scalar(i))
223222

224223
for num in range(n):
225-
var result: NDArray[out_dtype] = NDArray[out_dtype](
224+
var result: NDArray[dtype] = NDArray[dtype](
226225
NDArrayShape(array.num_elements() - (num + 1))
227226
)
228227
for i in range(array1.num_elements() - 1):
229228
result.store(
230-
i, (array1.load[1](i + 1) - array1.load[1](i)).cast[out_dtype]()
229+
i, (array1.load[1](i + 1) - array1.load[1](i))
231230
)
232231
array1 = result
233232
return array1

numojo/math/calculus/differentiation.mojo

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,16 @@ fn gradient[
3434
Returns:
3535
The integral of y over x using the trapezoidal rule.
3636
"""
37-
# var result: NDArray[dtype] = NDArray[dtype](x.shape(), random=False)
38-
# var space: NDArray[dtype] = NDArray[dtype](x.shape(), random=False)
39-
# if spacing.isa[NDArray[dtype]]():
40-
# for i in range(x.num_elements()):
41-
# space[i] = spacing._get_ptr[NDArray[dtype]]()[][i].cast[dtype]()
42-
43-
# elif spacing.isa[Scalar[dtype]]():
44-
# var int: Scalar[dtype] = spacing._get_ptr[Scalar[dtype]]()[]
45-
# space = numojo.arange[dtype, dtype](1, x.num_elements(), step=int)
46-
37+
4738
var result: NDArray[dtype] = NDArray[dtype](x.shape(), random=False)
4839
var space: NDArray[dtype] = core.arange[dtype](
49-
1, x.num_elements() + 1, step=spacing.cast[dtype]()
40+
1, x.num_elements() + 1, step=spacing
5041
)
5142
var hu: Scalar[dtype] = space.get_scalar(1)
5243
var hd: Scalar[dtype] = space.get_scalar(0)
5344
result.store(
5445
0,
55-
(x.get_scalar(1).cast[dtype]() - x.get_scalar(0).cast[dtype]())
46+
(x.get_scalar(1) - x.get_scalar(0))
5647
/ (hu - hd),
5748
)
5849

@@ -61,8 +52,8 @@ fn gradient[
6152
result.store(
6253
x.num_elements() - 1,
6354
(
64-
x.get_scalar(x.num_elements() - 1).cast[dtype]()
65-
- x.get_scalar(x.num_elements() - 2).cast[dtype]()
55+
x.get_scalar(x.num_elements() - 1)
56+
- x.get_scalar(x.num_elements() - 2)
6657
)
6758
/ (hu - hd),
6859
)
@@ -75,9 +66,9 @@ fn gradient[
7566
i - 1
7667
)
7768
var fi: Scalar[dtype] = (
78-
hd**2 * x.get_scalar(i + 1).cast[dtype]()
79-
+ (hu**2 - hd**2) * x.get_scalar(i).cast[dtype]()
80-
- hu**2 * x.get_scalar(i - 1).cast[dtype]()
69+
hd**2 * x.get_scalar(i + 1)
70+
+ (hu**2 - hd**2) * x.get_scalar(i)
71+
- hu**2 * x.get_scalar(i - 1)
8172
) / (hu * hd * (hu + hd))
8273
result.store(i, fi)
8374

numojo/math/calculus/integral.mojo

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@ from algorithm import Static2DTileUnitFunc as Tile2DFunc
1313

1414
# naive loop implementation, optimize later
1515
fn trapz[
16-
in_dtype: DType, out_dtype: DType = DType.float32
17-
](y: NDArray[in_dtype], x: NDArray[in_dtype]) raises -> SIMD[out_dtype, 1]:
16+
dtype: DType = DType.float64
17+
](y: NDArray[ dtype], x: NDArray[ dtype]) raises -> SIMD[dtype, 1]:
1818
"""
1919
Compute the integral of y over x using the trapezoidal rule.
2020
2121
Parameters:
22-
in_dtype: Input data type.
23-
out_dtype: Output data type, defaults to float32.
22+
dtype: The element type.
2423
2524
Args:
2625
y: An array.
@@ -37,16 +36,16 @@ fn trapz[
3736
raise Error("x and y must have the same shape")
3837

3938
# move this check to compile time using constrained?
40-
if is_inttype[in_dtype]() and not is_floattype[out_dtype]():
39+
if is_inttype[ dtype]() and not is_floattype[dtype]():
4140
raise Error(
4241
"output dtype `Fdtype` must be a floating-point type if input dtype"
4342
" `Idtype` is not a floating-point type"
4443
)
4544

46-
var integral: SIMD[out_dtype] = 0.0
45+
var integral: SIMD[dtype] = 0.0
4746
for i in range(x.num_elements() - 1):
48-
var temp = (x.get_scalar(i + 1) - x.get_scalar(i)).cast[out_dtype]() * (
47+
var temp = (x.get_scalar(i + 1) - x.get_scalar(i)) * (
4948
y.get_scalar(i) + y.get_scalar(i + 1)
50-
).cast[out_dtype]() / 2.0
49+
) / 2.0
5150
integral += temp
5251
return integral

numojo/math/linalg/linalg.mojo

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,15 @@ from algorithm import Static2DTileUnitFunc as Tile2DFunc
1515

1616

1717
fn cross[
18-
in_dtype: DType, out_dtype: DType = DType.float32
19-
](array1: NDArray[in_dtype], array2: NDArray[in_dtype]) raises -> NDArray[
20-
out_dtype
18+
dtype: DType = DType.float64
19+
](array1: NDArray[ dtype], array2: NDArray[ dtype]) raises -> NDArray[
20+
dtype
2121
]:
2222
"""
2323
Compute the cross product of two arrays.
2424
2525
Parameters
26-
in_dtype: Input data type.
27-
out_dtype: Output data type, defaults to float32.
26+
dtype: The element type.
2827
2928
Args:
3029
array1: A array.
@@ -38,27 +37,27 @@ fn cross[
3837
"""
3938

4039
if array1.ndshape.ndlen == array2.ndshape.ndlen == 3:
41-
var array3: NDArray[out_dtype] = NDArray[out_dtype](NDArrayShape(3))
40+
var array3: NDArray[dtype] = NDArray[dtype](NDArrayShape(3))
4241
array3.store(
4342
0,
4443
(
4544
array1.get_scalar(1) * array2.get_scalar(2)
4645
- array1.get_scalar(2) * array2.get_scalar(1)
47-
).cast[out_dtype](),
46+
),
4847
)
4948
array3.store(
5049
1,
5150
(
5251
array1.get_scalar(2) * array2.get_scalar(0)
5352
- array1.get_scalar(0) * array2.get_scalar(2)
54-
).cast[out_dtype](),
53+
),
5554
)
5655
array3.store(
5756
2,
5857
(
5958
array1.get_scalar(0) * array2.get_scalar(1)
6059
- array1.get_scalar(1) * array2.get_scalar(0)
61-
).cast[out_dtype](),
60+
),
6261
)
6362
return array3
6463
else:

0 commit comments

Comments
 (0)