@@ -111,114 +111,3 @@ fn matmul_naive[
111111 C.store(m, n, val = C.load(m, n) + A.load(m, k) * B.load(k, n))
112112
113113 return C
114-
115-
116- @always_inline
117- fn calculate_block [
118- M : Int, N : Int, K : Int, BLOCK_M : Int, BLOCK_N : Int, nelts : Int, dtype : DType
119- ](
120- res : NDArray[dtype],
121- t1 : NDArray[dtype],
122- t2 : NDArray[dtype],
123- bm : Int,
124- bn : Int,
125- ) raises :
126- # Compute tile
127- var acc = stack_allocation[BLOCK_M * BLOCK_N , dtype]()
128- memset_zero[dtype](acc, BLOCK_M * BLOCK_N )
129-
130- for k in range (K):
131- # @unroll
132- for m in range (BLOCK_M ):
133-
134- @parameter
135- fn inner_n [nelts : Int](n : Int):
136- try :
137- acc.store[width=nelts](
138- m * BLOCK_N + n,
139- SIMD [dtype, nelts]
140- .splat(t1[(bm + m) * K + k])
141- .fma(
142- t2.load[width=nelts](k * N + (bn + n)),
143- acc.load[width=nelts](m * BLOCK_N + n),
144- ),
145- )
146- except e:
147- print (" Error" , e)
148-
149- vectorize[inner_n, nelts](BLOCK_N )
150-
151- # Store tile
152- for m in range (BLOCK_M ):
153-
154- @parameter
155- fn vec_store [nelts : Int](n : Int):
156- var temp = acc.load[width=nelts](m * BLOCK_N + n)
157- res.data.store[width=nelts]((bm + m) * N + (bn + n), val = temp)
158-
159- vectorize[vec_store, nelts](BLOCK_N )
160-
161-
162- @always_inline
163- fn dot [
164- t10 : Int, t11 : Int, t21 : Int, dtype : DType
165- ](res : NDArray[dtype], t1 : NDArray[dtype], t2 : NDArray[dtype]) raises :
166- alias M = t10 # t1[0]
167- alias K = t11 # t1[1], t2[0]
168- alias N = t21
169-
170- # simdwidthof[dtype]() = 8 for float32
171- alias nelts = simdwidthof[dtype]()
172- alias BLOCK_N = 8 * 2
173- alias BLOCK_M = 6
174- alias THREADS = 6 # num_logical_cores()
175-
176- alias BLOCK_N_REMAINDER = N % BLOCK_N
177- alias BLOCK_M_REMAINDER = M % BLOCK_M
178-
179- @parameter
180- fn bm_par (m_outer : Int):
181- var bm = m_outer * BLOCK_M
182-
183- for n_outer in range (0 , N // BLOCK_N ):
184- var bn = n_outer * BLOCK_N
185- try :
186- calculate_block[M, N, K, BLOCK_M , BLOCK_N , nelts](
187- res, t1, t2, bm, bn
188- )
189- except e:
190- print (" Error" , e)
191-
192- # Handle the remainder of N
193- @parameter
194- if BLOCK_N_REMAINDER > 0 :
195- var bn = N - BLOCK_N_REMAINDER
196- try :
197- calculate_block[M, N, K, BLOCK_M , BLOCK_N_REMAINDER , nelts](
198- res, t1, t2, bm, bn
199- )
200- except e:
201- print (" Error" , e)
202-
203- parallelize[bm_par](M // BLOCK_M , M // BLOCK_M )
204-
205- # Handle the remainder of M
206- @parameter
207- if BLOCK_M_REMAINDER > 0 :
208- var bm = M - BLOCK_M_REMAINDER
209-
210- for n_outer in range (0 , N // BLOCK_N ):
211- var bn = n_outer * BLOCK_N
212-
213- calculate_block[M, N, K, BLOCK_M_REMAINDER , BLOCK_N , nelts](
214- res, t1, t2, bm, bn
215- )
216-
217- # Handle corner remainder
218- @parameter
219- if BLOCK_N_REMAINDER > 0 :
220- var bn = N - BLOCK_N_REMAINDER
221-
222- calculate_block[
223- M, N, K, BLOCK_M_REMAINDER , BLOCK_N_REMAINDER , nelts
224- ](res, t1, t2, bm, bn)
0 commit comments