Skip to content

Commit 7d80fa8

Browse files
committed
removed basalt matmul from testing
1 parent e817c8f commit 7d80fa8

File tree

1 file changed

+0
-111
lines changed

1 file changed

+0
-111
lines changed

numojo/math/linalg/matmul.mojo

Lines changed: 0 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -111,114 +111,3 @@ fn matmul_naive[
111111
C.store(m, n, val=C.load(m, n) + A.load(m, k) * B.load(k, n))
112112

113113
return C
114-
115-
116-
@always_inline
fn calculate_block[
    M: Int, N: Int, K: Int, BLOCK_M: Int, BLOCK_N: Int, nelts: Int, dtype: DType
](
    res: NDArray[dtype],
    t1: NDArray[dtype],
    t2: NDArray[dtype],
    bm: Int,
    bn: Int,
) raises:
    """
    Compute one BLOCK_M x BLOCK_N output tile and write it into `res`.

    The tile is accumulated in a stack buffer and copied out at the end.
    All three arrays are addressed with flat offsets: `t1` at
    `(bm + row) * K + k`, `t2` at `k * N + (bn + col)` and `res` at
    `(bm + row) * N + (bn + col)`.

    Parameters:
        M: Not referenced in the body; kept for the existing call sites.
        N: Row stride used for the `t2` and `res` offsets.
        K: Length of the reduction loop and row stride used for `t1`.
        BLOCK_M: Number of tile rows.
        BLOCK_N: Number of tile columns.
        nelts: SIMD width handed to `vectorize`.
        dtype: Element type of the arrays.

    Args:
        res: Destination array; written through `res.data`.
        t1: Left operand.
        t2: Right operand.
        bm: Row offset of the tile.
        bn: Column offset of the tile.
    """
    # Tile accumulator, zero-initialised before the reduction starts.
    var tile = stack_allocation[BLOCK_M * BLOCK_N, dtype]()
    memset_zero[dtype](tile, BLOCK_M * BLOCK_N)

    # Accumulate: tile[row, col] += t1[bm + row, k] * t2[k, bn + col].
    for k in range(K):
        for row in range(BLOCK_M):

            @parameter
            fn fma_step[width: Int](col: Int):
                # Errors are reported and swallowed here rather than
                # propagated; the original closure does the same.
                try:
                    var lhs = SIMD[dtype, width].splat(t1[(bm + row) * K + k])
                    var rhs = t2.load[width=width](k * N + (bn + col))
                    var partial = tile.load[width=width](row * BLOCK_N + col)
                    tile.store[width=width](
                        row * BLOCK_N + col, lhs.fma(rhs, partial)
                    )
                except e:
                    print("Error", e)

            vectorize[fma_step, nelts](BLOCK_N)

    # Copy the finished tile into the destination buffer, row by row.
    for row in range(BLOCK_M):

        @parameter
        fn flush_row[width: Int](col: Int):
            res.data.store[width=width](
                (bm + row) * N + (bn + col),
                val=tile.load[width=width](row * BLOCK_N + col),
            )

        vectorize[flush_row, nelts](BLOCK_N)
160-
161-
162-
@always_inline
fn dot[
    t10: Int, t11: Int, t21: Int, dtype: DType
](res: NDArray[dtype], t1: NDArray[dtype], t2: NDArray[dtype]) raises:
    """
    Tiled matrix product driver: fills `res` tile by tile via
    `calculate_block`.

    The output is split into BLOCK_M x BLOCK_N tiles. Full row bands are
    dispatched through `parallelize`; the leftover rows (`M % BLOCK_M`) and
    leftover columns (`N % BLOCK_N`) are handled by extra `calculate_block`
    calls with the remainder as the tile size.

    Parameters:
        t10: Used as M (per the original comment, the first dimension of
            `t1`).
        t11: Used as K (per the original comment, the second dimension of
            `t1` and the first of `t2`).
        t21: Used as N (presumably the second dimension of `t2`; confirm
            against callers).
        dtype: Element type of the arrays.

    Args:
        res: Destination array.
        t1: Left operand.
        t2: Right operand.

    Raises:
        Errors from the M-remainder tiles propagate to the caller. Errors
        raised inside the parallel row bands are caught and printed instead.
    """
    alias M = t10  # t1[0]
    alias K = t11  # t1[1], t2[0]
    alias N = t21

    # SIMD width is target dependent (e.g. 8 for float32 with 256-bit vectors).
    alias nelts = simdwidthof[dtype]()
    alias BLOCK_N = 8 * 2
    alias BLOCK_M = 6

    alias FULL_M_TILES = M // BLOCK_M
    alias FULL_N_TILES = N // BLOCK_N
    alias BLOCK_N_REMAINDER = N % BLOCK_N
    alias BLOCK_M_REMAINDER = M % BLOCK_M

    @parameter
    fn bm_par(m_outer: Int):
        var bm = m_outer * BLOCK_M

        for n_outer in range(0, FULL_N_TILES):
            var bn = n_outer * BLOCK_N
            # Errors are printed, not propagated, inside the parallel worker.
            try:
                calculate_block[M, N, K, BLOCK_M, BLOCK_N, nelts](
                    res, t1, t2, bm, bn
                )
            except e:
                print("Error", e)

        # Handle the remainder of N
        @parameter
        if BLOCK_N_REMAINDER > 0:
            var bn = N - BLOCK_N_REMAINDER
            try:
                calculate_block[M, N, K, BLOCK_M, BLOCK_N_REMAINDER, nelts](
                    res, t1, t2, bm, bn
                )
            except e:
                print("Error", e)

    # Skip the parallel dispatch entirely when there is no full row band
    # (M < BLOCK_M), so `parallelize` is never asked for zero workers.
    @parameter
    if FULL_M_TILES > 0:
        parallelize[bm_par](FULL_M_TILES, FULL_M_TILES)

    # Handle the remainder of M
    @parameter
    if BLOCK_M_REMAINDER > 0:
        var bm = M - BLOCK_M_REMAINDER

        for n_outer in range(0, FULL_N_TILES):
            var bn = n_outer * BLOCK_N

            calculate_block[M, N, K, BLOCK_M_REMAINDER, BLOCK_N, nelts](
                res, t1, t2, bm, bn
            )

        # Handle corner remainder
        @parameter
        if BLOCK_N_REMAINDER > 0:
            var bn = N - BLOCK_N_REMAINDER

            calculate_block[
                M, N, K, BLOCK_M_REMAINDER, BLOCK_N_REMAINDER, nelts
            ](res, t1, t2, bm, bn)

0 commit comments

Comments
 (0)