Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@ TEST_SRC := $(TEST_DIR)/run_benchmark.cpp
TARGET_MAIN := $(BUILD_DIR)/run_benchmark

# correctness suite
CORR_SRC := $(TEST_DIR)/test_correctness.cpp # fix filename
TARGET_CORR := $(BUILD_DIR)/test_correctness # fix binary name
CORR_SRC := $(TEST_DIR)/test_correctness.cpp
TARGET_CORR := $(BUILD_DIR)/test_correctness

# matrix ops test
MATRIX_OPS_SRC := $(TEST_DIR)/test_matrix_ops.cpp
TARGET_MATRIX_OPS := $(BUILD_DIR)/test_matrix_ops

HEADERS := \
$(SRC_DIR)/layout_policies.hpp \
Expand All @@ -35,9 +39,9 @@ HEADERS := \
$(SRC_DIR)/lut_utils.hpp \
$(SRC_DIR)/post_processing.hpp

.PHONY: all run test clean pytest
.PHONY: all run test clean pytest matrix_ops matrix_ops_float matrix_ops_lut

all: $(BUILD_DIR) $(TARGET_MAIN) $(TARGET_CORR) mpgemm$(PYEXT)
all: $(BUILD_DIR) $(TARGET_MAIN) $(TARGET_CORR) $(TARGET_MATRIX_OPS) mpgemm$(PYEXT)

# ensure build directory exists
$(BUILD_DIR):
Expand All @@ -48,11 +52,15 @@ $(TARGET_MAIN): $(TEST_SRC) $(HEADERS)
$(CXX) $(CXXFLAGS) $(TEST_SRC) -o $(TARGET_MAIN) $(LDFLAGS) $(LDLIBS)

# build correctness suite
$(TARGET_CORR): $(CORR_SRC) $(HEADERS)
$(TARGET_CORR): $(CORR_SRC) $(HEADERS)
$(CXX) $(CXXFLAGS) $(CORR_SRC) -o $(TARGET_CORR) $(LDFLAGS) $(LDLIBS)

# build matrix ops test
$(TARGET_MATRIX_OPS): $(MATRIX_OPS_SRC) $(HEADERS)
$(CXX) $(CXXFLAGS) -pthread $(MATRIX_OPS_SRC) -o $(TARGET_MATRIX_OPS) $(LDFLAGS) $(LDLIBS)

# build pybind11 module
mpgemm$(PYEXT): src/bindings.cpp $(HEADERS)
mpgemm$(PYEXT): src/bindings.cpp $(HEADERS)
$(CXX) $(CXXFLAGS) $(PYBIND11_INC) -fPIC -shared src/bindings.cpp -o $@ $(LDFLAGS) $(LDLIBS)

# run pytest
Expand All @@ -65,6 +73,19 @@ run: all
test: $(TARGET_CORR)
./$(TARGET_CORR)

matrix_ops: $(TARGET_MATRIX_OPS)
@echo "Running float version..."
@./$(TARGET_MATRIX_OPS) float
@echo "\nRunning LUT version..."
@./$(TARGET_MATRIX_OPS) lut

# Convenience targets
matrix_ops_float: $(TARGET_MATRIX_OPS)
./$(TARGET_MATRIX_OPS) float

matrix_ops_lut: $(TARGET_MATRIX_OPS)
./$(TARGET_MATRIX_OPS) lut

clean:
rm -rf $(BUILD_DIR)
rm -f mpgemm$(PYEXT)
215 changes: 151 additions & 64 deletions src/matrix_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <type_traits>
#include <vector>
#include <immintrin.h>
#include <thread>
#include <mutex>

// =============================================================
// Helper: unpack a Matrix<> that uses Int4Storage into a
Expand All @@ -29,82 +31,57 @@ std::vector<uint8_t> unpack_int4(const Mat4& M)
}

// =============================================================
// Naive reference GEMM (kept unchanged for correctness checks)
// High-performance parallel GEMM implementation
// Supports any numeric type through templates
// =============================================================

template<typename MA, typename MB>
auto matmul(const MA& A, const MB& B)
auto matmul(const MA& A, const MB& B, size_t num_threads = 4)
{
using T = decltype(A.at(0, 0));
static_assert(std::is_same_v<T, decltype(B.at(0, 0))>, "Element types must match");

size_t M = A.rows(), K = A.cols(), N = B.cols();
Matrix<T, RowMajor, PlainStorage<T>> C(M, N);
std::vector<std::thread> threads;

// Number of rows handled by each thread
size_t rows_per_thread = (M + num_threads - 1) / num_threads;

// Assign a row range to each thread
for (size_t t = 0; t < num_threads; ++t) {
threads.emplace_back([&, t]() {
size_t start_row = t * rows_per_thread;
size_t end_row = std::min(start_row + rows_per_thread, M);

// Per-thread local result matrix
Matrix<T, RowMajor, PlainStorage<T>> local_C(M, N);

for (size_t i = start_row; i < end_row; ++i) {
for (size_t k = 0; k < K; ++k) {
T a = A.at(i, k);
for (size_t j = 0; j < N; ++j) {
local_C.set(i, j, local_C.at(i, j) + a * B.at(k, j));
}
}
}

for (size_t i = 0; i < M; ++i)
for (size_t k = 0; k < K; ++k) {
T a = A.at(i, k);
for (size_t j = 0; j < N; ++j)
C.set(i, j, C.at(i, j) + a * B.at(k, j));
}
return C;
}

// =============================================================
// High‑speed LUT GEMM — expects *unpacked* uint8 buffers.
// * Au shape: M × K contiguous
// * Bu shape: K × N contiguous
// Works with or without AVX2 (scalar fallback).
// =============================================================

auto matmul_lut_fast(const std::vector<uint8_t>& Au,
const std::vector<uint8_t>& Bu,
size_t M, size_t K, size_t N,
const ProductLookupTable<uint8_t, uint8_t, int32_t>& lut)
{
const int32_t* lut_ptr = lut.data();
const int32_t stride = static_cast<int32_t>(lut.row_stride());

#if defined(__AVX2__)
const __m256i vstride = _mm256_set1_epi32(stride);
#endif

Matrix<int32_t, RowMajor, PlainStorage<int32_t>> C(M, N);

for (size_t i = 0; i < M; ++i) {
const uint8_t* rowA = &Au[i * K];
for (size_t j = 0; j < N; ++j) {
int32_t acc = 0;
size_t k = 0;

#if defined(__AVX2__)
// --- AVX2: process 8 elements (k dimension) per iteration ---
for (; k + 7 < K; k += 8) {
// load 8 uint8 from A row / B column (row‑major & col‑major buffers)
__m128i w8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(rowA + k));
__m128i a8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&Bu[k * N + j]));

__m256i w32 = _mm256_cvtepu8_epi32(w8);
__m256i a32 = _mm256_cvtepu8_epi32(a8);
__m256i idx = _mm256_add_epi32(_mm256_mullo_epi32(w32, vstride), a32);
__m256i vals = _mm256_i32gather_epi32(lut_ptr, idx, 4);

// horizontal sum of 8 lanes
__m128i low = _mm256_castsi256_si128(vals);
__m128i high = _mm256_extracti128_si256(vals, 1);
__m128i sum = _mm_add_epi32(low, high);
sum = _mm_hadd_epi32(sum, sum);
sum = _mm_hadd_epi32(sum, sum);
acc += _mm_cvtsi128_si32(sum);
// Merge partial results under the lock
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
for (size_t i = start_row; i < end_row; ++i) {
for (size_t j = 0; j < N; ++j) {
C.set(i, j, C.at(i, j) + local_C.at(i, j));
}
}
#endif
// --- scalar remainder (or full loop if no AVX2) ---
for (; k < K; ++k)
acc += lut_ptr[rowA[k] * stride + Bu[k * N + j]];
});
}

C.set(i, j, acc);
}
// Wait for all threads to finish
for (auto& thread : threads) {
thread.join();
}

return C;
}

Expand Down Expand Up @@ -149,4 +126,114 @@ Matrix<T> matmul_mkl(const Matrix<T>& A, const Matrix<T>& B) {
}
return C;
}
#endif
#endif

// =============================================================
// High‑speed LUT GEMM — expects *unpacked* uint8 buffers.
// * Au shape: M × K contiguous
// * Bu shape: K × N contiguous
// Works with or without AVX2 (scalar fallback).
// =============================================================

auto matmul_lut_fast(const std::vector<uint8_t>& Au,
const std::vector<uint8_t>& Bu,
size_t M, size_t K, size_t N,
const ProductLookupTable<uint8_t, uint8_t, int32_t>& lut,
size_t block_size = 64, // 增加區塊大小以減少同步開銷
size_t num_threads = 4)
{
const int32_t* lut_ptr = lut.data();
const int32_t stride = static_cast<int32_t>(lut.row_stride());
Matrix<int32_t, RowMajor, PlainStorage<int32_t>> C(M, N);
std::vector<std::thread> threads;
std::mutex mtx;

// 預取 LUT 到 L1 cache
for (size_t i = 0; i < 16; ++i) {
for (size_t j = 0; j < 16; ++j) {
_mm_prefetch(reinterpret_cast<const char*>(&lut_ptr[i * stride + j]), _MM_HINT_T0);
}
}

// 計算每個執行緒處理的行數
size_t rows_per_thread = (M + num_threads - 1) / num_threads;

// 為每個執行緒分配工作
for (size_t t = 0; t < num_threads; ++t) {
threads.emplace_back([&, t]() {
size_t start_row = t * rows_per_thread;
size_t end_row = std::min(start_row + rows_per_thread, M);

// 為每個執行緒創建局部結果矩陣
Matrix<int32_t, RowMajor, PlainStorage<int32_t>> local_C(M, N);

// 分塊處理
for (size_t i = start_row; i < end_row; i += block_size) {
size_t i_end = std::min(i + block_size, end_row);

for (size_t j = 0; j < N; j += block_size) {
size_t j_end = std::min(j + block_size, N);

for (size_t k = 0; k < K; k += block_size) {
size_t k_end = std::min(k + block_size, K);

// 處理當前區塊
for (size_t ii = i; ii < i_end; ++ii) {
const uint8_t* rowA = &Au[ii * K];

for (size_t jj = j; jj < j_end; ++jj) {
int32_t acc = 0;

#if defined(__AVX2__)
// 使用 AVX2 處理 8 個元素
for (size_t kk = k; kk + 7 < k_end; kk += 8) {
__m128i w8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(rowA + kk));
__m128i a8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&Bu[kk * N + jj]));

__m256i w32 = _mm256_cvtepu8_epi32(w8);
__m256i a32 = _mm256_cvtepu8_epi32(a8);
__m256i idx = _mm256_add_epi32(_mm256_mullo_epi32(w32, _mm256_set1_epi32(stride)), a32);

// 預取下一個 LUT 值
_mm_prefetch(reinterpret_cast<const char*>(&lut_ptr[_mm256_extract_epi32(idx, 0)]), _MM_HINT_T0);

__m256i vals = _mm256_i32gather_epi32(lut_ptr, idx, 4);

// 水平加總
__m128i low = _mm256_castsi256_si128(vals);
__m128i high = _mm256_extracti128_si256(vals, 1);
__m128i sum = _mm_add_epi32(low, high);
sum = _mm_hadd_epi32(sum, sum);
sum = _mm_hadd_epi32(sum, sum);
acc += _mm_cvtsi128_si32(sum);
}
#endif
// 處理剩餘元素
for (size_t kk = k + ((k_end - k) & ~7); kk < k_end; ++kk) {
acc += lut_ptr[rowA[kk] * stride + Bu[kk * N + jj]];
}

local_C.set(ii, jj, local_C.at(ii, jj) + acc);
}
}
}
}
}

// 合併結果
std::lock_guard<std::mutex> lock(mtx);
for (size_t i = start_row; i < end_row; ++i) {
for (size_t j = 0; j < N; ++j) {
C.set(i, j, C.at(i, j) + local_C.at(i, j));
}
}
});
}

// 等待所有執行緒完成
for (auto& thread : threads) {
thread.join();
}

return C;
}
Loading