Skip to content

Commit c5af6b9

Browse files
authored
Merge pull request #88 from mmenendezg/experimental
Add NDArray initialization with random values in an specified interval
2 parents b9d15c9 + 2403a56 commit c5af6b9

File tree

1 file changed

+107
-1
lines changed

1 file changed

+107
-1
lines changed

numojo/core/ndarray.mojo

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Implements N-Dimensional Array
2020
"""
2121

2222
from builtin.type_aliases import AnyLifetime
23-
from random import rand
23+
from random import rand, random_si64, random_float64
2424
from builtin.math import pow
2525
from builtin.bool import all as allb
2626
from builtin.bool import any as anyb
@@ -922,6 +922,112 @@ struct NDArray[dtype: DType = DType.float64](
922922
for i in range(self.ndshape.ndsize):
923923
self.data[i] = data[i]
924924

925+
@always_inline("nodebug")
926+
fn __init__(
927+
inout self,
928+
*shape: Int,
929+
min: Scalar[dtype],
930+
max: Scalar[dtype],
931+
order: String = "C",
932+
) raises:
933+
"""
934+
NDArray initialization for variadic shape with random values between min and max.
935+
936+
Args:
937+
shape: Variadic shape.
938+
min: Minimum value for the NDArray.
939+
max: Maximum value for the NDArray.
940+
order: Memory order C or F.
941+
942+
Example:
943+
```mojo
944+
import numojo as nm
945+
fn main() raises:
946+
var A = nm.NDArray[DType.float16](2, 2, min=0.0, max=10.0)
947+
print(A)
948+
```
949+
A is an array with shape 2 x 2 and randomly values between 0 and 10.
950+
The output goes as follows.
951+
952+
```console
953+
[[ 6.046875 6.98046875 ]
954+
[ 6.6484375 1.736328125 ]]
955+
2-D array Shape: [2, 2] DType: float16
956+
```
957+
"""
958+
self.ndim = shape.__len__()
959+
self.ndshape = NDArrayShape(shape)
960+
self.stride = NDArrayStride(shape, offset=0, order=order)
961+
self.coefficient = NDArrayStride(shape, offset=0, order=order)
962+
self.datatype = dtype
963+
self.order = order
964+
self.data = DTypePointer[dtype].alloc(self.ndshape.ndsize)
965+
if dtype.is_floating_point():
966+
for i in range(self.ndshape.ndsize):
967+
self.data.store(
968+
i,
969+
random_float64(min.cast[DType.float64](), max.cast[DType.float64]()).cast[dtype]()
970+
)
971+
elif dtype.is_integral():
972+
for i in range(self.ndshape.ndsize):
973+
self.data.store(
974+
i,
975+
random_si64(int(min), int(max)).cast[dtype]()
976+
)
977+
978+
@always_inline("nodebug")
979+
fn __init__(
980+
inout self,
981+
shape: List[Int],
982+
min: Scalar[dtype],
983+
max: Scalar[dtype],
984+
order: String = "C",
985+
) raises:
986+
"""
987+
NDArray initialization for list shape with random values between min and max.
988+
989+
Args:
990+
shape: List of shape.
991+
min: Minimum value for the NDArray.
992+
max: Maximum value for the NDArray.
993+
order: Memory order C or F.
994+
995+
Example:
996+
```mojo
997+
import numojo as nm
998+
fn main() raises:
999+
var A = nm.NDArray[DType.float16](List[Int](2, 2), min=0.0, max=10.0)
1000+
print(A)
1001+
```
1002+
A is an array with shape 2 x 2 and randomly values between 0 and 10.
1003+
The output goes as follows.
1004+
1005+
```console
1006+
[[ 6.046875 6.98046875 ]
1007+
[ 6.6484375 1.736328125 ]]
1008+
2-D array Shape: [2, 2] DType: float16
1009+
```
1010+
"""
1011+
self.ndim = shape.__len__()
1012+
self.ndshape = NDArrayShape(shape)
1013+
self.stride = NDArrayStride(shape, offset=0, order=order)
1014+
self.coefficient = NDArrayStride(shape, offset=0, order=order)
1015+
self.datatype = dtype
1016+
self.order = order
1017+
self.data = DTypePointer[dtype].alloc(self.ndshape.ndsize)
1018+
if dtype.is_floating_point():
1019+
for i in range(self.ndshape.ndsize):
1020+
self.data.store(
1021+
i,
1022+
random_float64(min.cast[DType.float64](), max.cast[DType.float64]()).cast[dtype]()
1023+
)
1024+
elif dtype.is_integral():
1025+
for i in range(self.ndshape.ndsize):
1026+
self.data.store(
1027+
i,
1028+
random_si64(int(min), int(max)).cast[dtype]()
1029+
)
1030+
9251031
fn __init__(
9261032
inout self,
9271033
text: String,

0 commit comments

Comments
 (0)