Skip to content

Commit bfa5825

Browse files
authored
Merge pull request #133 from shivasankarka/main
Getter setters
2 parents 84ec78d + d1370d7 commit bfa5825

File tree

3 files changed

+213
-238
lines changed

3 files changed

+213
-238
lines changed

numojo/core/ndarray.mojo

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -478,13 +478,12 @@ struct NDArray[dtype: DType = DType.float64](
478478
var size_at_dim: Int = self.ndshape[i]
479479
slice_list.append(Slice(0, size_at_dim))
480480

481-
# self.__setitem__(slice_list=slice_list, val=val)
482481
var n_slices: Int = len(slice_list)
483482
var ndims: Int = 0
484483
var count: Int = 0
485484
var spec: List[Int] = List[Int]()
486485
for i in range(n_slices):
487-
self._adjust_slice_(slice_list[i], self.ndshape[i])
486+
# self._adjust_slice_(slice_list[i], self.ndshape[i])
488487
if (
489488
slice_list[i].start.value() >= self.ndshape[i]
490489
or slice_list[i].end.value() > self.ndshape[i]
@@ -611,43 +610,11 @@ struct NDArray[dtype: DType = DType.float64](
611610
Example:
612611
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
613612
"""
614-
print("slices: ", slices[0], slices[1], slices[2])
615613
var n_slices: Int = len(slices)
616614
var ndims: Int = 0
617615
var count: Int = 0
618616
var spec: List[Int] = List[Int]()
619-
var slice_list: List[Slice] = List[Slice]()
620-
for i in range(n_slices):
621-
var start: Int = 0
622-
var end: Int = 0
623-
if slices[i].start is None and slices[i].end is None:
624-
start = 0
625-
end = self.ndshape[i]
626-
temp = Slice(
627-
start=Optional(start),
628-
end=Optional(end),
629-
step=Optional(slices[i].step),
630-
)
631-
slice_list.append(temp)
632-
if slices[i].start is None and slices[i].end is not None:
633-
start = 0
634-
temp = Slice(
635-
start=Optional(start),
636-
end=Optional(slices[i].end.value()),
637-
step=Optional(slices[i].step),
638-
)
639-
slice_list.append(temp)
640-
if slices[i].start is not None and slices[i].end is None:
641-
end = self.ndshape[i]
642-
temp = Slice(
643-
start=Optional(slices[i].start.value()),
644-
end=Optional(end),
645-
step=Optional(slices[i].step),
646-
)
647-
slice_list.append(temp)
648-
if slices[i].start is not None and slices[i].end is not None:
649-
slice_list.append(slices[i])
650-
617+
var slice_list: List[Slice] = self._adjust_slice_(slices)
651618
for i in range(n_slices):
652619
if (
653620
slice_list[i].start.value() >= self.ndshape[i]
@@ -867,24 +834,50 @@ struct NDArray[dtype: DType = DType.float64](
867834
var idx: Int = _get_index(index, self.coefficient)
868835
return self.data.load[width=1](idx)
869836

870-
fn _adjust_slice_(self, inout span: Slice, dim: Int):
837+
fn _adjust_slice_(self, slice_list: List[Slice]) raises -> List[Slice]:
871838
"""
872839
Adjusts the slice values to lie within 0 and dim.
873840
"""
874-
if span.start or span.end:
875-
var start = int(span.start.value())
876-
var end = int(span.end.value())
877-
if start < 0:
878-
start = dim + start
879-
if not span.end:
880-
end = dim
881-
elif end < 0:
882-
end = dim + end
883-
if end > dim:
884-
end = dim
885-
if end < start:
886-
start = 0
887-
end = 0
841+
var n_slices: Int = slice_list.__len__()
842+
var slices = List[Slice]()
843+
for i in range(n_slices):
844+
if i >= self.ndim:
845+
raise Error("Error: Number of slices exceeds array dimensions")
846+
847+
var start: Int = 0
848+
var end: Int = self.ndshape[i]
849+
var step: Int = 1
850+
if slice_list[i].start is not None:
851+
start = slice_list[i].start.value()
852+
if start < 0:
853+
# start += self.ndshape[i]
854+
raise Error(
855+
"Error: Negative indexing in slices not supported"
856+
" currently"
857+
)
858+
859+
if slice_list[i].end is not None:
860+
end = slice_list[i].end.value()
861+
if end < 0:
862+
# end += self.ndshape[i] + 1
863+
raise Error(
864+
"Error: Negative indexing in slices not supported"
865+
" currently"
866+
)
867+
868+
step = slice_list[i].step
869+
if step == 0:
870+
raise Error("Error: Slice step cannot be zero")
871+
872+
slices.append(
873+
Slice(
874+
start=Optional(start),
875+
end=Optional(end),
876+
step=Optional(step),
877+
)
878+
)
879+
880+
return slices^
888881

889882
fn __getitem__(self, owned *slices: Slice) raises -> Self:
890883
"""
@@ -908,23 +901,24 @@ struct NDArray[dtype: DType = DType.float64](
908901
var narr: Self = self[slice_list]
909902
return narr
910903

911-
fn __getitem__(self, owned slices: List[Slice]) raises -> Self:
904+
fn __getitem__(self, owned slice_list: List[Slice]) raises -> Self:
912905
"""
913906
Retreive slices of an array from list of slices.
914907
915908
Example:
916909
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
917910
"""
918911

919-
var n_slices: Int = slices.__len__()
912+
var n_slices: Int = slice_list.__len__()
920913
if n_slices > self.ndim or n_slices < self.ndim:
921914
raise Error("Error: No of slices do not match shape")
922915

923916
var ndims: Int = 0
924917
var spec: List[Int] = List[Int]()
925918
var count: Int = 0
919+
920+
var slices: List[Slice] = self._adjust_slice_(slice_list)
926921
for i in range(slices.__len__()):
927-
self._adjust_slice_(slices[i], self.ndshape[i])
928922
if (
929923
slices[i].start.value() >= self.ndshape[i]
930924
or slices[i].end.value() > self.ndshape[i]

test.mojo

Lines changed: 48 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,15 @@ fn test_bool_masks2() raises:
218218
print(temp3.ndshape, temp3.stride, temp3.ndshape.ndsize)
219219

220220

221-
fn test_creation_routines() raises:
222-
var x = linspace[numojo.f32](0.0, 60.0, 60)
223-
var y = ones[numojo.f32](3, 2)
224-
var z = logspace[numojo.f32](-3, 0, 60)
225-
var w = arange[f32](0.0, 24.0, step=1)
226-
print(x)
227-
print(y)
228-
print(z)
229-
print(w)
221+
# fn test_creation_routines() raises:
222+
# var x = linspace[numojo.f32](0.0, 60.0, 60)
223+
# var y = ones[numojo.f32](shape(3, 2))
224+
# var z = logspace[numojo.f32](-3, 0, 60)
225+
# var w = arange[f32](0.0, 24.0, step=1)
226+
# print(x)
227+
# print(y)
228+
# print(z)
229+
# print(w)
230230

231231

232232
fn test_slicing() raises:
@@ -267,6 +267,33 @@ fn test_slicing() raises:
267267
print(slicedy3)
268268
# print("Time taken: ", (time.now() - start)/1e9/10)
269269

270+
# var np = Python.import_module("numpy")
271+
# y = nm.arange[nm.f32](0.0, 24.0, step=1)
272+
# y.reshape(2, 3, 4, order="C")
273+
# np_y = np.arange(0, 24, dtype=np.float32).reshape(2, 3, 4, order="C")
274+
# print(y)
275+
# print(np_y)
276+
# print()
277+
# # Test slicing
278+
# slicedy = y[:, :, 1:2]
279+
# print("slicedy: ", slicedy)
280+
# np_slicedy = np.take(
281+
# np.take(
282+
# np.take(np_y, np.arange(0, 2), axis=0), np.arange(0, 3), axis=1
283+
# ),
284+
# np.arange(1, 2),
285+
# axis=2,
286+
# )
287+
# print("np_slicedy: ", np_slicedy)
288+
# np_slicedy = np.squeeze(
289+
# np_slicedy, axis=2
290+
# ) # Remove the dimension with size 1
291+
# var np_arr = slicedy.to_numpy()
292+
# print()
293+
# print(np_arr)
294+
# print(np_slicedy)
295+
# print(np.all(np.equal(np_arr, np_slicedy)))
296+
270297

271298
fn test_rand_funcs[
272299
dtype: DType = DType.float64
@@ -331,18 +358,19 @@ def test_solve():
331358

332359
fn test_setter() raises:
333360
print("Testing setter")
334-
# var A = NDArray[i16](2, 3, 2, fill=Scalar[i16](1))
335-
# var B = NDArray[i16](3, 2, fill=Scalar[i16](2))
336-
# A[0] = B
337-
# print(A)
361+
var A = nm.full[i16](3, 3, 3, fill_value=1)
362+
var B = nm.full[i16](3, 3, fill_value=2)
363+
A[0] = B
364+
print(A)
338365

339-
var A = ndarray[i16](3, 3, 3, fill=Scalar[i16](1))
340-
print("1: ", A)
341-
var D = nm.random.rand[i16](3, 3, min=0, max=100)
342-
A[1] = D # sets the elements of A[1:2, :, :] with the array `D`
343-
print("2: ", A)
344-
A[:, 0:1, :] = D # sets the elements of A[:, 0:1, :] with the array `D`
345-
print("3: ", A)
366+
var A1 = nm.full[i16](3, 4, 5, fill_value=1)
367+
print("A1: ", A1)
368+
var D1 = nm.random.rand[i16](3, 5, min=0, max=100)
369+
A1[:, 0:1, :] = D1 # sets the elements of A[:, 0:1, :] with the array `D`
370+
print("A3: ", A1)
371+
var D = nm.random.rand[i16](4, 5, min=0, max=100)
372+
A1[1] = D # sets the elements of A[1:2, :, :] with the array `D`
373+
print("A2: ", A1)
346374

347375

348376
fn main() raises:
@@ -359,118 +387,3 @@ fn main() raises:
359387
# test_solve()
360388
# test_linalg()
361389
test_setter()
362-
363-
364-
# var x = numojo.full[numojo.f32](3, 2, fill_value=16.0)
365-
# var x = numojo.NDArray[numojo.f32](data=List[SIMD[numojo.f32, 1]](1,2,3,4,5,6,7,8,9,10,11,12), shape=List[Int](2,3,2),
366-
# order="F")
367-
# print(x)
368-
# print(x.stride)
369-
# var y = numojo.NDArray[numojo.f32](data=List[SIMD[numojo.f32, 1]](1,2,3,4,5,6,7,8,9,10,11,12), shape=List[Int](2,3,2),
370-
# order="C")
371-
# print(y)
372-
# print(y.stride)
373-
# print()
374-
# var summed = numojo.stats.sum(x,0)
375-
# print(summed)
376-
# print(numojo.stats.mean(x,0))
377-
# print(numojo.stats.cumprod(x))
378-
379-
# var maxval = x.max(axis=0)
380-
# print(maxval)
381-
382-
383-
# var array = nj.NDArray[nj.f64](10,10)
384-
# for i in range(array.size()):
385-
# array[i] = i
386-
# # for i in range(10):
387-
# # for j in range(10):
388-
# # print(array[i, j])
389-
# var res = array.sum(axis=0)
390-
# print(res)
391-
392-
# var arr2 = numojo.NDArray[numojo.f32](data=List[SIMD[numojo.f32, 1]](1.0, 2.0, 4.0, 7.0, 11.0, 16.0),
393-
# shape=List[Int](6))
394-
# var np = Python.import_module("numpy")
395-
# var np_arr = numojo.to_numpy(arr2)
396-
# print(np_arr)
397-
# var result = numojo.math.calculus.differentiation.gradient[numojo.f32](arr2, spacing=1.0)
398-
# print(result)
399-
# print(arr1.any())
400-
# print(arr1.all())
401-
# print(arr1.argmax())
402-
# print(arr1.argmin())
403-
# print(arr1.astype[numojo.i16]())
404-
# print(arr1.flatten(inplace=True))
405-
# print(r.ndshape, r.stride, r.ndshape.ndsize)
406-
# var t0 = time.now()
407-
# var res = numojo.math.linalg.matmul_tiled_unrolled_parallelized[numojo.f32](arr, arr1)
408-
# print((time.now()-t0)/1e9)
409-
# var res = numojo.math.linalg.matmul_tiled_unrolled_parallelized[numojo.f32](arr, arr1)
410-
# print(res)
411-
# print(arr)
412-
# print("2x3x1")
413-
# var sliced = arr[:, :, 1:2]
414-
# print(sliced)
415-
416-
# print("1x3x4")
417-
# var sliced1 = arr[::2, :]
418-
# print(sliced1)
419-
420-
# print("1x3x1")
421-
# var sliced2 = arr[1:2, :, 2:3]
422-
# print(sliced2)
423-
424-
# var result = numojo.NDArray(3, 3)
425-
# numojo.math.linalg.dot[t10=3, t11=3, t21=3, dtype=numojo.f32](result, arr, arr1)
426-
# print(result)
427-
428-
429-
# fn main() raises:
430-
# var size:VariadicList[Int] = VariadicList[Int](16,128,256,512,1024)
431-
# alias size1: StaticIntTuple[5] = StaticIntTuple[5](16,128,256,512,1024)
432-
# var times:List[Float64] = List[Float64]()
433-
# alias type:DType = DType.float64
434-
# measure_time[type, size1](size, times)
435-
436-
# fn measure_time[dtype:DType, size1: StaticIntTuple[5]](size:VariadicList[Int], inout times:List[Float64]) raises:
437-
438-
# for i in range(size.__len__()):
439-
# var arr1 = numojo.NDArray[dtype](size[i], size[i])
440-
# var arr2 = numojo.NDArray[dtype](size[i], size[i])
441-
# var arr_mul = numojo.NDArray[dtype](size[i], size[i])
442-
443-
# var t0 = time.now()
444-
# @parameter
445-
# for i in range(50):
446-
# numojo.math.linalg.dot[t10=size1[i], t11=size1[i], t21=size1[i], dtype=dtype](arr_mul, arr1, arr2)
447-
# # var arr_mul = numojo.math.linalg.matmul_parallelized[dtype](arr1, arr2)
448-
# # var arr_mul = numojo.math.linalg.matmul_tiled_unrolled_parallelized[dtype](arr1, arr2)
449-
# keep(arr_mul.unsafe_ptr())
450-
# times.append(((time.now()-t0)/1e9)/50)
451-
452-
# for i in range(size.__len__()):
453-
# print(times[i])
454-
455-
# fn main() raises:
456-
# alias type:DType = DType.float16
457-
# measure_time[type]()
458-
459-
# fn measure_time[dtype:DType]() raises:
460-
# var size:VariadicList[Int] = VariadicList[Int](16,128,256,512,1024)
461-
# alias size1: StaticIntTuple[5] = StaticIntTuple[5](16,128,256,512,1024)
462-
463-
# var n = 4
464-
# alias m = 4
465-
# var arr1 = numojo.NDArray[dtype](size[n], size[n])
466-
# var arr2 = numojo.NDArray[dtype](size[n], size[n])
467-
# var arr_mul = numojo.NDArray[dtype](size[n], size[n])
468-
469-
# var t0 = time.now()
470-
471-
# for _ in range(50):
472-
# numojo.math.linalg.dot[t10=size1[m], t11=size1[m], t21=size1[m], dtype=dtype](arr_mul, arr1, arr2)
473-
# # var arr_mul = numojo.math.linalg.matmul_parallelized[dtype](arr1, arr2)
474-
# # var arr_mul = numojo.math.linalg.matmul_tiled_unrolled_parallelized[dtype](arr1, arr2)
475-
# keep(arr_mul.unsafe_ptr())
476-
# print(((time.now()-t0)/1e9)/50)

0 commit comments

Comments
 (0)