Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def nki_matmul_basic_(lhsT, rhs):

# Create a tensor in SBUF and copy the result from PSUM back to SBUF, and cast to expected output data-type
result_sbuf = nl.ndarray(result_psum.shape, dtype=result.dtype, buffer=nl.sbuf)
nisa.tensor_copy(dst=result_sbuf, src=result_psum, dtype=result.dtype)
nisa.tensor_copy(dst=result_sbuf, src=result_psum)

# The result of a [64,128] x [128,512] matrix multiplication has a shape of [64, 512].
# This dictates which indices to use to address the result tile.
Expand Down Expand Up @@ -124,7 +124,7 @@ def nki_matmul_tiled_(lhsT, rhs):

# Copy the result from PSUM back to SBUF, and cast to expected output data-type
res_sb = nl.ndarray(res_psum.shape, dtype=result.dtype, buffer=nl.sbuf)
nisa.tensor_copy(dst=res_sb, src=res_psum, dtype=result.dtype)
nisa.tensor_copy(dst=res_sb, src=res_psum)

# Copy the result from SBUF to HBM.
nisa.dma_copy(dst=result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N], src=res_sb)
Expand Down Expand Up @@ -203,8 +203,8 @@ def nki_matmul_hoist_load_(lhsT, rhs):
nisa.nc_matmul(dst=res_psum, stationary=lhsT_tiles[k], moving=rhs_tiles[k])

# Copy the result from PSUM back to SBUF, and cast to expected output data-type
res_sb = nl.ndarray(shape=(TILE_M, TILE_N), dtype=nl.float32, buffer=nl.sbuf)
nisa.tensor_copy(dst=res_sb, src=res_psum, dtype=result.dtype)
res_sb = nl.ndarray(shape=(TILE_M, TILE_N), dtype=result.dtype, buffer=nl.sbuf)
nisa.tensor_copy(dst=res_sb, src=res_psum)

# Copy the result from SBUF to HBM.
nisa.dma_copy(dst=result[m * TILE_M:(m + 1) * TILE_M, n * TILE_N:(n + 1) * TILE_N], src=res_sb)
Expand Down Expand Up @@ -386,8 +386,8 @@ def nki_matmul_fully_optimized_(
for bn_idx in range(TILES_IN_BLOCK_N):
# Create the result tile (uninitialized)
tile = nl.ndarray(shape=(TILE_M, TILE_N), dtype=lhsT.dtype, buffer=nl.sbuf)
# Initialize the tile 0.0
nisa.memset(dst=tile, value=0.0)
# Initialize the tile to 0
nisa.memset(dst=tile, value=0)
# Append the tile to block_n array.
block_n.append(tile)
# Append block_n array to block_m array.
Expand Down
Loading