Skip to content

Commit 28cb84c

Browse files
committed
Fix issue in matmul due to initialization of the C.
1 parent 3b132e3 commit 28cb84c

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

numojo/core/ndshape.mojo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from builtin.type_aliases import AnyLifetime
1010

1111
alias Shp = NDArrayShape
1212

13+
1314
@register_passable("trivial")
1415
struct NDArrayShape[dtype: DType = DType.int32](Stringable, Formattable):
1516
"""Implements the NDArrayShape."""

numojo/math/linalg/matmul.mojo

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ fn matmul_parallelized[
8686

8787
alias width = max(simdwidthof[dtype](), 16)
8888

89-
var C: NDArray[dtype] = NDArray[dtype](
89+
var C: NDArray[dtype] = zeros[dtype](
9090
A.ndshape.load_int(0), B.ndshape.load_int(1)
9191
)
9292
var t0 = A.ndshape.load_int(0)
@@ -125,7 +125,7 @@ fn matmul_naive[
125125
"""
126126
Matrix multiplication with three nested loops.
127127
"""
128-
var C: NDArray[dtype] = NDArray[dtype](
128+
var C: NDArray[dtype] = zeros[dtype](
129129
A.ndshape.load_int(0), B.ndshape.load_int(1)
130130
)
131131
for m in range(C.ndshape.load_int(0)):

numojo/prelude.mojo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ from numojo.prelude import *
2222

2323
from .core.ndarray import NDArray
2424
from .core.index import Idx
25-
from .core.ndarrayshape import NDArrayShape
25+
from .core.ndshape import NDArrayShape
2626
from .core.datatypes import i8, i16, i32, i64, u8, u16, u32, u64, f16, f32, f64

tests/test_math.mojo

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numojo as nm
2+
from numojo.prelude import *
23
from time import now
34
from python import Python, PythonObject
45
from utils_for_test import check, check_is_close
@@ -56,6 +57,17 @@ def test_sin_par():
5657

5758

5859
# ! MATMUL RESULTS IN A SEGMENTATION FAULT EXCEPT FOR NAIVE ONE, BUT NAIVE OUTPUTS WRONG VALUES
60+
61+
62+
def test_matmul_small():
63+
var np = Python.import_module("numpy")
64+
var arr = nm.ones[i8](4, 4)
65+
var np_arr = np.ones((4, 4), dtype=np.int8)
66+
check_is_close(
67+
arr @ arr, np.matmul(np_arr, np_arr), "Dunder matmul is broken"
68+
)
69+
70+
5971
def test_matmul():
6072
var np = Python.import_module("numpy")
6173
var arr = nm.arange[nm.f64](0, 100)

0 commit comments

Comments
 (0)