Skip to content

Commit f85df10

Browse files
authored
[core] Improve NDIter to allow arbitrary axis to travel (#221)
- Improve `_NDIter` to allow arbitrary axis to travel. - Add method `ith()` to get the i-th item of the iterator. - Add `swapaxes()` for shape and strides. - Add `offset()` for `Item` type to get offset. - Constructor for `Item` from index and shape. - Add tests for C or F array with `nditer` from C or F orders.
1 parent c552c80 commit f85df10

File tree

9 files changed

+370
-68
lines changed

9 files changed

+370
-68
lines changed

docs/features.md

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Available functions and objects by topics (also see imports within `__init__` fi
88
- Array manipulation routines (`numojo.routines.manipulation`)
99
- Bit-wise operations (`numojo.routines.bitwise`)
1010
- Constants (`numojo.routines.constants`)
11+
- Indexing routines (`numojo.routines.indexing`)
1112
- Input and output (`numojo.routines.io`)
1213
- Text files (`numojo.routines.files`)
1314
- Text formatting options (`numojo.routines.formatting`)
@@ -20,7 +21,13 @@ Available functions and objects by topics (also see imports within `__init__` fi
2021
- Array contents (`numojo.routines.contents`)
2122
- Truth value testing (`numojo.routines.truth`)
2223
- Mathematical functions (`numojo.routines.math`)
23-
- Extrema: maxT, minT
24+
- Sums (`numojo.routines.sums`)
25+
- `sum()`
26+
- Products (`numojo.routines.products`)
27+
- `prod()`
28+
- Differences (`numojo.routines.differences`)
29+
- Extrema (`numojo.routines.math.extrema`)
30+
- `max()`, `min()`
2431
- Trigonometry: acos, asin, atan, cos, sin, tan, atan2, hypot
2532
- Hyperbolic: acosh, asinh, atanh, cosh, sinh, tanh
2633
- Floating: copysign
@@ -35,17 +42,14 @@ Available functions and objects by topics (also see imports within `__init__` fi
3542
- Indexing (`numojo.routines.indexing`)
3643
- Miscellaneous (`numojo.routines.misc`)
3744
- Rounding (`numojo.routines.rounding`)
38-
- Sums, products, differences (`numojo.routines.sums`, `numojo.routines.products`, `numojo.routines.differences`)
39-
- sum, prod
4045
- Trigonometric functions (`numojo.routines.trig`)
4146
- Random sampling (`numojo.routines.random`)
4247
- Sorting, searching, and counting (`numojo.routines.sorting`, `numojo.routines.searching`)
4348
- Statistics (`numojo.routines.statistics`)
44-
- mean, mode, median
45-
- variance
46-
- Averages and variances (`numojo.routines.averages`)
49+
- Averages and variances (`numojo.routines.averages`)
50+
- `mean()`, `mode()`, `median()`
51+
- `variance()`, `std()`
4752

4853
To-be-implemented functions and objects by topics:
4954

5055
- Mathematical functions: abs, floor, ceil, trunc, round, roundeven, round_half_down, round_half_up, reciprocal, nextafter, remainder
51-
- Statistical functions: pvariance, pstdev, std

numojo/core/item.mojo

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Implements Item type.
66

77
from builtin.type_aliases import Origin
88
from memory import UnsafePointer, memset_zero, memcpy
9+
from os import abort
910
from sys import simdwidthof
1011
from utils import Variant
1112

@@ -18,6 +19,10 @@ alias item = Item
1819

1920
@register_passable
2021
struct Item(CollectionElement):
22+
"""
23+
Specifies the indices of an item of an array.
24+
"""
25+
2126
var _buf: UnsafePointer[Int]
2227
var ndim: Int
2328

@@ -63,15 +68,14 @@ struct Item(CollectionElement):
6368
for i in range(self.ndim):
6469
(self._buf + i).init_pointee_copy(Int(args[i]))
6570

66-
@always_inline("nodebug")
6771
fn __init__(
6872
out self,
73+
*,
6974
ndim: Int,
7075
initialized: Bool,
7176
) raises:
7277
"""
7378
Construct Item with number of dimensions.
74-
7579
This method is useful when you want to create a Item with given ndim
7680
without knowing the Item values.
7781
@@ -80,15 +84,62 @@ struct Item(CollectionElement):
8084
initialized: Whether the shape is initialized.
8185
If yes, the values will be set to 0.
8286
If no, the values will be uninitialized.
87+
88+
Raises:
89+
Error: If the number of dimensions is negative.
8390
"""
8491
if ndim < 0:
85-
raise Error("Number of dimensions must be non-negative.")
92+
raise Error(
93+
"\nError in `Item.__init__()`: "
94+
"Number of dimensions must be non-negative."
95+
)
96+
8697
self.ndim = ndim
8798
self._buf = UnsafePointer[Int]().alloc(ndim)
8899
if initialized:
89100
for i in range(ndim):
90101
(self._buf + i).init_pointee_copy(0)
91102

103+
fn __init__(out self, idx: Int, shape: NDArrayShape) raises:
104+
"""
105+
Get indices of the i-th item of the array of the given shape.
106+
The item traverse the array in C-order.
107+
108+
Args:
109+
idx: The i-th item of the array.
110+
shape: The strides of the array.
111+
112+
Examples:
113+
114+
The following example demonstrates how to get the indices (coordinates)
115+
of the 123-th item of a 3D array with shape (20, 30, 40).
116+
117+
```console
118+
>>> from numojo.prelude import *
119+
>>> var item = Item(123, Shape(20, 30, 40))
120+
>>> print(item)
121+
Item at index: (0,3,3) Length: 3
122+
```
123+
"""
124+
125+
if (idx < 0) or (idx >= shape.size_of_array()):
126+
raise Error(
127+
String(
128+
"\nError in `Item.__init__(out self, idx: Int, shape:"
129+
" NDArrayShape)`: idx {} out of range [{}, {})."
130+
).format(idx, 0, shape.size_of_array())
131+
)
132+
133+
self.ndim = shape.ndim
134+
self._buf = UnsafePointer[Int]().alloc(self.ndim)
135+
136+
var strides = NDArrayStrides(shape, order="C")
137+
var remainder = idx
138+
139+
for i in range(self.ndim):
140+
(self._buf + i).init_pointee_copy(remainder // strides._buf[i])
141+
remainder %= strides._buf[i]
142+
92143
@always_inline("nodebug")
93144
fn __copyinit__(mut self, other: Self):
94145
"""Copy construct the tuple.
@@ -115,7 +166,7 @@ struct Item(CollectionElement):
115166

116167
@always_inline("nodebug")
117168
fn __getitem__[T: Indexer](self, idx: T) raises -> Int:
118-
"""Get the value at the specified index.
169+
"""Gets the value at the specified index.
119170
120171
Parameter:
121172
T: Type of values. It can be converted to `Int` with `Int()`.
@@ -205,6 +256,37 @@ struct Item(CollectionElement):
205256
+ String(self.ndim)
206257
)
207258

259+
# ===-------------------------------------------------------------------===#
260+
# Other methods
261+
# ===-------------------------------------------------------------------===#
262+
263+
fn offset(self, strides: NDArrayStrides) -> Int:
264+
"""
265+
Calculates the offset of the item according to strides.
266+
267+
Args:
268+
strides: The strides of the array.
269+
270+
Returns:
271+
The offset of the item.
272+
273+
Examples:
274+
275+
```mojo
276+
from numojo.prelude import *
277+
var item = Item(1, 2, 3)
278+
var strides = nm.Strides(4, 3, 2)
279+
print(item.offset(strides))
280+
# This prints `16`.
281+
```
282+
.
283+
"""
284+
285+
var offset: Int = 0
286+
for i in range(self.ndim):
287+
offset += self._buf[i] * strides._buf[i]
288+
return offset
289+
208290

209291
@value
210292
struct _ItemIter[

0 commit comments

Comments
 (0)