Skip to content

Commit 234afbf

Browse files
committed
fixed getitem adjust slice function errors
1 parent 2cca3f1 commit 234afbf

File tree

1 file changed

+44
-51
lines changed

1 file changed

+44
-51
lines changed

numojo/core/ndarray.mojo

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ struct NDArray[dtype: DType = DType.float64](
484484
var count: Int = 0
485485
var spec: List[Int] = List[Int]()
486486
for i in range(n_slices):
487-
self._adjust_slice_(slice_list[i], self.ndshape[i])
487+
# self._adjust_slice_(slice_list[i], self.ndshape[i])
488488
if (
489489
slice_list[i].start.value() >= self.ndshape[i]
490490
or slice_list[i].end.value() > self.ndshape[i]
@@ -611,43 +611,11 @@ struct NDArray[dtype: DType = DType.float64](
611611
Example:
612612
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
613613
"""
614-
print("slices: ", slices[0], slices[1], slices[2])
615614
var n_slices: Int = len(slices)
616615
var ndims: Int = 0
617616
var count: Int = 0
618617
var spec: List[Int] = List[Int]()
619-
var slice_list: List[Slice] = List[Slice]()
620-
for i in range(n_slices):
621-
var start: Int = 0
622-
var end: Int = 0
623-
if slices[i].start is None and slices[i].end is None:
624-
start = 0
625-
end = self.ndshape[i]
626-
temp = Slice(
627-
start=Optional(start),
628-
end=Optional(end),
629-
step=Optional(slices[i].step),
630-
)
631-
slice_list.append(temp)
632-
if slices[i].start is None and slices[i].end is not None:
633-
start = 0
634-
temp = Slice(
635-
start=Optional(start),
636-
end=Optional(slices[i].end.value()),
637-
step=Optional(slices[i].step),
638-
)
639-
slice_list.append(temp)
640-
if slices[i].start is not None and slices[i].end is None:
641-
end = self.ndshape[i]
642-
temp = Slice(
643-
start=Optional(slices[i].start.value()),
644-
end=Optional(end),
645-
step=Optional(slices[i].step),
646-
)
647-
slice_list.append(temp)
648-
if slices[i].start is not None and slices[i].end is not None:
649-
slice_list.append(slices[i])
650-
618+
var slice_list: List[Slice] = self._adjust_slice_(slices)
651619
for i in range(n_slices):
652620
if (
653621
slice_list[i].start.value() >= self.ndshape[i]
@@ -867,24 +835,48 @@ struct NDArray[dtype: DType = DType.float64](
867835
var idx: Int = _get_index(index, self.coefficient)
868836
return self.data.load[width=1](idx)
869837

870-
fn _adjust_slice_(self, inout span: Slice, dim: Int):
838+
fn _adjust_slice_(self, slice_list: List[Slice]) raises -> List[Slice]:
871839
"""
872840
Adjusts the slice values to lie within 0 and dim.
873841
"""
874-
if span.start or span.end:
875-
var start = int(span.start.value())
876-
var end = int(span.end.value())
877-
if start < 0:
878-
start = dim + start
879-
if not span.end:
880-
end = dim
881-
elif end < 0:
882-
end = dim + end
883-
if end > dim:
884-
end = dim
885-
if end < start:
842+
var n_slices: Int = slice_list.__len__()
843+
var slices = List[Slice]()
844+
for i in range(n_slices):
845+
var start: Int = 0
846+
var end: Int = 0
847+
if slice_list[i].start is None and slice_list[i].end is None:
848+
start = 0
849+
end = self.ndshape[i]
850+
temp = Slice(
851+
start=Optional(start),
852+
end=Optional(end),
853+
step=Optional(slice_list[i].step),
854+
)
855+
slices.append(temp)
856+
if slice_list[i].start is None and slice_list[i].end is not None:
886857
start = 0
887-
end = 0
858+
temp = Slice(
859+
start=Optional(start),
860+
end=Optional(slice_list[i].end.value()),
861+
step=Optional(slice_list[i].step),
862+
)
863+
slices.append(temp)
864+
if slice_list[i].start is not None and slice_list[i].end is None:
865+
end = self.ndshape[i]
866+
temp = Slice(
867+
start=Optional(slice_list[i].start.value()),
868+
end=Optional(end),
869+
step=Optional(slice_list[i].step),
870+
)
871+
slices.append(temp)
872+
if (
873+
slice_list[i].start is not None
874+
and slice_list[i].end is not None
875+
):
876+
slices.append(slice_list[i])
877+
else:
878+
raise Error("Error: Undefined Slice")
879+
return slices^
888880

889881
fn __getitem__(self, owned *slices: Slice) raises -> Self:
890882
"""
@@ -908,23 +900,24 @@ struct NDArray[dtype: DType = DType.float64](
908900
var narr: Self = self[slice_list]
909901
return narr
910902

911-
fn __getitem__(self, owned slices: List[Slice]) raises -> Self:
903+
fn __getitem__(self, owned slice_list: List[Slice]) raises -> Self:
912904
"""
913905
Retreive slices of an array from list of slices.
914906
915907
Example:
916908
`arr[1:3, 2:4]` returns the corresponding sliced array (2 x 2).
917909
"""
918910

919-
var n_slices: Int = slices.__len__()
911+
var n_slices: Int = slice_list.__len__()
920912
if n_slices > self.ndim or n_slices < self.ndim:
921913
raise Error("Error: No of slices do not match shape")
922914

923915
var ndims: Int = 0
924916
var spec: List[Int] = List[Int]()
925917
var count: Int = 0
918+
919+
var slices: List[Slice] = self._adjust_slice_(slice_list)
926920
for i in range(slices.__len__()):
927-
self._adjust_slice_(slices[i], self.ndshape[i])
928921
if (
929922
slices[i].start.value() >= self.ndshape[i]
930923
or slices[i].end.value() > self.ndshape[i]

0 commit comments

Comments
 (0)