From 3926b2619df5b348071a300657c4be4b96e6346c Mon Sep 17 00:00:00 2001 From: 5000user5000 Date: Wed, 7 May 2025 18:43:15 +0800 Subject: [PATCH 1/3] for analysis naive and lut gemm --- tests/test_matrix_ops.cpp | 85 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/test_matrix_ops.cpp diff --git a/tests/test_matrix_ops.cpp b/tests/test_matrix_ops.cpp new file mode 100644 index 0000000..bdef844 --- /dev/null +++ b/tests/test_matrix_ops.cpp @@ -0,0 +1,85 @@ +#include "matrix_ops.hpp" +#include "matrix.hpp" +#include +#include +#include + +void run_float() { + const size_t M = 1024; + const size_t K = 1024; + const size_t N = 1024; + + // 建立 float 矩陣 + Matrix> A_float(M, K); + Matrix> B_float(K, N); + + // 初始化矩陣 + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < K; ++j) { + A_float.set(i, j, static_cast(i + j) / 1000.0f); + } + } + + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < N; ++j) { + B_float.set(i, j, static_cast(i * j) / 1000.0f); + } + } + + auto start = std::chrono::high_resolution_clock::now(); + auto result = matmul(A_float, B_float); + auto end = std::chrono::high_resolution_clock::now(); + auto time = std::chrono::duration_cast(end - start).count(); + std::cout << "Float time: " << time << " ms\n"; +} + +void run_lut() { + const size_t M = 1024; + const size_t K = 1024; + const size_t N = 1024; + + // 建立 int4 矩陣 + Matrix A_int4(M, K); + Matrix B_int4(K, N); + + // 初始化矩陣 + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < K; ++j) { + A_int4.set(i, j, (i + j) % 16); + } + } + + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < N; ++j) { + B_int4.set(i, j, (i * j) % 16); + } + } + + // 建立查找表 + ProductLookupTable lut(16, 16); + + auto start = std::chrono::high_resolution_clock::now(); + auto result = matmul_lut_fast(unpack_int4(A_int4), unpack_int4(B_int4), M, K, N, lut); + auto end = std::chrono::high_resolution_clock::now(); + auto time = std::chrono::duration_cast(end - start).count(); + std::cout << "LUT time: " << time << " ms\n"; +} + +int main(int argc, char* argv[]) { + if (argc != 2) { + std::cout << "Usage: " << argv[0] << " [float|lut]\n"; + return 1; + } + + std::string mode = argv[1]; + if (mode == "float") { + run_float(); + } else if (mode == "lut") { + run_lut(); + } else { + std::cout << "Invalid mode. Use 'float' or 'lut'\n"; + return 1; + } + + return 0; +} \ No newline at end of file From 63366b335f99321ba1d5a254ae93e366753ab7eb Mon Sep 17 00:00:00 2001 From: 5000user5000 Date: Wed, 7 May 2025 18:47:49 +0800 Subject: [PATCH 2/3] perf: add parallel SIMD implementation with blocking for both float and LUT versions --- Makefile | 35 ++++++-- src/matrix_ops.hpp | 215 +++++++++++++++++++++++++++++++-------------- 2 files changed, 179 insertions(+), 71 deletions(-) diff --git a/Makefile b/Makefile index 27e6671..db756b4 100644 --- a/Makefile +++ b/Makefile @@ -24,8 +24,12 @@ TEST_SRC := $(TEST_DIR)/run_benchmark.cpp TARGET_MAIN := $(BUILD_DIR)/run_benchmark # correctness suite -CORR_SRC := $(TEST_DIR)/test_correctness.cpp # fix filename -TARGET_CORR := $(BUILD_DIR)/test_correctness # fix binary name +CORR_SRC := $(TEST_DIR)/test_correctness.cpp +TARGET_CORR := $(BUILD_DIR)/test_correctness + +# matrix ops test +MATRIX_OPS_SRC := $(TEST_DIR)/test_matrix_ops.cpp +TARGET_MATRIX_OPS := $(BUILD_DIR)/test_matrix_ops HEADERS := \ $(SRC_DIR)/layout_policies.hpp \ @@ -35,24 +39,28 @@ HEADERS := \ $(SRC_DIR)/lut_utils.hpp \ $(SRC_DIR)/post_processing.hpp -.PHONY: all run test clean pytest +.PHONY: all run test clean pytest matrix_ops matrix_ops_float matrix_ops_lut -all: $(BUILD_DIR) $(TARGET_MAIN) $(TARGET_CORR) mpgemm$(PYEXT) +all: $(BUILD_DIR) $(TARGET_MAIN) $(TARGET_CORR) $(TARGET_MATRIX_OPS) mpgemm$(PYEXT) # ensure build directory exists $(BUILD_DIR): mkdir -p $(BUILD_DIR) # build main -$(TARGET_MAIN): $(TEST_SRC) $(HEADERS) +$(TARGET_MAIN): $(BUILD_DIR) $(TEST_SRC) $(HEADERS) $(CXX) $(CXXFLAGS) $(TEST_SRC) -o $(TARGET_MAIN) $(LDFLAGS) $(LDLIBS) # build correctness suite -$(TARGET_CORR): $(CORR_SRC) $(HEADERS) +$(TARGET_CORR): $(BUILD_DIR) $(CORR_SRC) $(HEADERS) $(CXX) $(CXXFLAGS) $(CORR_SRC) -o $(TARGET_CORR) $(LDFLAGS) $(LDLIBS) +# build matrix ops test +$(TARGET_MATRIX_OPS): $(BUILD_DIR) $(MATRIX_OPS_SRC) $(HEADERS) + $(CXX) $(CXXFLAGS) -pthread $(MATRIX_OPS_SRC) -o $(TARGET_MATRIX_OPS) $(LDFLAGS) $(LDLIBS) + # build pybind11 module -mpgemm$(PYEXT): src/bindings.cpp $(HEADERS) +mpgemm$(PYEXT): $(BUILD_DIR) src/bindings.cpp $(HEADERS) $(CXX) $(CXXFLAGS) $(PYBIND11_INC) -fPIC -shared src/bindings.cpp -o $@ $(LDFLAGS) $(LDLIBS) # run pytest @@ -65,6 +73,19 @@ run: all test: $(TARGET_CORR) ./$(TARGET_CORR) +matrix_ops: $(TARGET_MATRIX_OPS) + @echo "Running float version..." + @./$(TARGET_MATRIX_OPS) float + @echo "\nRunning LUT version..." + @./$(TARGET_MATRIX_OPS) lut + +# 方便的命令 +matrix_ops_float: $(TARGET_MATRIX_OPS) + ./$(TARGET_MATRIX_OPS) float + +matrix_ops_lut: $(TARGET_MATRIX_OPS) + ./$(TARGET_MATRIX_OPS) lut + clean: rm -rf $(BUILD_DIR) rm -f mpgemm$(PYEXT) \ No newline at end of file diff --git a/src/matrix_ops.hpp b/src/matrix_ops.hpp index fbac1b4..3651a28 100644 --- a/src/matrix_ops.hpp +++ b/src/matrix_ops.hpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include // ============================================================= // Helper: unpack a Matrix<> that uses Int4Storage into a @@ -29,82 +31,57 @@ std::vector unpack_int4(const Mat4& M) } // ============================================================= -// Naive reference GEMM (kept unchanged for correctness checks) +// High-performance parallel GEMM implementation +// Supports any numeric type through templates // ============================================================= template -auto matmul(const MA& A, const MB& B) +auto matmul(const MA& A, const MB& B, size_t num_threads = 4) { using T = decltype(A.at(0, 0)); static_assert(std::is_same_v, "Element types must match"); size_t M = A.rows(), K = A.cols(), N = B.cols(); Matrix> C(M, N); + std::vector threads; + + // 計算每個執行緒處理的行數 + size_t rows_per_thread = (M + num_threads - 1) / num_threads; + + // 為每個執行緒分配工作 + for (size_t t = 0; t < num_threads; ++t) { + threads.emplace_back([&, t]() { + size_t start_row = t * rows_per_thread; + size_t end_row = std::min(start_row + rows_per_thread, M); + + // 為每個執行緒創建局部結果矩陣 + Matrix> local_C(M, N); + + for (size_t i = start_row; i < end_row; ++i) { + for (size_t k = 0; k < K; ++k) { + T a = A.at(i, k); + for (size_t j = 0; j < N; ++j) { + local_C.set(i, j, local_C.at(i, j) + a * B.at(k, j)); + } + } + } - for (size_t i = 0; i < M; ++i) - for (size_t k = 0; k < K; ++k) { - T a = A.at(i, k); - for (size_t j = 0; j < N; ++j) - C.set(i, j, C.at(i, j) + a * B.at(k, j)); - } - return C; -} - -// ============================================================= -// High‑speed LUT GEMM — expects *unpacked* uint8 buffers. -// * Au shape: M × K contiguous -// * Bu shape: K × N contiguous -// Works with or without AVX2 (scalar fallback). -// ============================================================= - -auto matmul_lut_fast(const std::vector& Au, - const std::vector& Bu, - size_t M, size_t K, size_t N, - const ProductLookupTable& lut) -{ - const int32_t* lut_ptr = lut.data(); - const int32_t stride = static_cast(lut.row_stride()); - -#if defined(__AVX2__) - const __m256i vstride = _mm256_set1_epi32(stride); -#endif - - Matrix> C(M, N); - - for (size_t i = 0; i < M; ++i) { - const uint8_t* rowA = &Au[i * K]; - for (size_t j = 0; j < N; ++j) { - int32_t acc = 0; - size_t k = 0; - -#if defined(__AVX2__) - // --- AVX2: process 8 elements (k dimension) per iteration --- - for (; k + 7 < K; k += 8) { - // load 8 uint8 from A row / B column (row‑major & col‑major buffers) - __m128i w8 = _mm_loadl_epi64(reinterpret_cast(rowA + k)); - __m128i a8 = _mm_loadl_epi64(reinterpret_cast(&Bu[k * N + j])); - - __m256i w32 = _mm256_cvtepu8_epi32(w8); - __m256i a32 = _mm256_cvtepu8_epi32(a8); - __m256i idx = _mm256_add_epi32(_mm256_mullo_epi32(w32, vstride), a32); - __m256i vals = _mm256_i32gather_epi32(lut_ptr, idx, 4); - - // horizontal sum of 8 lanes - __m128i low = _mm256_castsi256_si128(vals); - __m128i high = _mm256_extracti128_si256(vals, 1); - __m128i sum = _mm_add_epi32(low, high); - sum = _mm_hadd_epi32(sum, sum); - sum = _mm_hadd_epi32(sum, sum); - acc += _mm_cvtsi128_si32(sum); + // 合併結果 + static std::mutex mtx; + std::lock_guard lock(mtx); + for (size_t i = start_row; i < end_row; ++i) { + for (size_t j = 0; j < N; ++j) { + C.set(i, j, C.at(i, j) + local_C.at(i, j)); + } } -#endif - // --- scalar remainder (or full loop if no AVX2) --- - for (; k < K; ++k) - acc += lut_ptr[rowA[k] * stride + Bu[k * N + j]]; + }); + } - C.set(i, j, acc); - } + // 等待所有執行緒完成 + for (auto& thread : threads) { + thread.join(); } + return C; } @@ -149,4 +126,114 @@ Matrix matmul_mkl(const Matrix& A, const Matrix& B) { } return C; } -#endif \ No newline at end of file +#endif + +// ============================================================= +// High‑speed LUT GEMM — expects *unpacked* uint8 buffers. +// * Au shape: M × K contiguous +// * Bu shape: K × N contiguous +// Works with or without AVX2 (scalar fallback). +// ============================================================= + +auto matmul_lut_fast(const std::vector& Au, + const std::vector& Bu, + size_t M, size_t K, size_t N, + const ProductLookupTable& lut, + size_t block_size = 64, // 增加區塊大小以減少同步開銷 + size_t num_threads = 4) +{ + const int32_t* lut_ptr = lut.data(); + const int32_t stride = static_cast(lut.row_stride()); + Matrix> C(M, N); + std::vector threads; + std::mutex mtx; + + // 預取 LUT 到 L1 cache + for (size_t i = 0; i < 16; ++i) { + for (size_t j = 0; j < 16; ++j) { + _mm_prefetch(reinterpret_cast(&lut_ptr[i * stride + j]), _MM_HINT_T0); + } + } + + // 計算每個執行緒處理的行數 + size_t rows_per_thread = (M + num_threads - 1) / num_threads; + + // 為每個執行緒分配工作 + for (size_t t = 0; t < num_threads; ++t) { + threads.emplace_back([&, t]() { + size_t start_row = t * rows_per_thread; + size_t end_row = std::min(start_row + rows_per_thread, M); + + // 為每個執行緒創建局部結果矩陣 + Matrix> local_C(M, N); + + // 分塊處理 + for (size_t i = start_row; i < end_row; i += block_size) { + size_t i_end = std::min(i + block_size, end_row); + + for (size_t j = 0; j < N; j += block_size) { + size_t j_end = std::min(j + block_size, N); + + for (size_t k = 0; k < K; k += block_size) { + size_t k_end = std::min(k + block_size, K); + + // 處理當前區塊 + for (size_t ii = i; ii < i_end; ++ii) { + const uint8_t* rowA = &Au[ii * K]; + + for (size_t jj = j; jj < j_end; ++jj) { + int32_t acc = 0; + +#if defined(__AVX2__) + // 使用 AVX2 處理 8 個元素 + for (size_t kk = k; kk + 7 < k_end; kk += 8) { + __m128i w8 = _mm_loadl_epi64(reinterpret_cast(rowA + kk)); + __m128i a8 = _mm_loadl_epi64(reinterpret_cast(&Bu[kk * N + jj])); + + __m256i w32 = _mm256_cvtepu8_epi32(w8); + __m256i a32 = _mm256_cvtepu8_epi32(a8); + __m256i idx = _mm256_add_epi32(_mm256_mullo_epi32(w32, _mm256_set1_epi32(stride)), a32); + + // 預取下一個 LUT 值 + _mm_prefetch(reinterpret_cast(&lut_ptr[_mm256_extract_epi32(idx, 0)]), _MM_HINT_T0); + + __m256i vals = _mm256_i32gather_epi32(lut_ptr, idx, 4); + + // 水平加總 + __m128i low = _mm256_castsi256_si128(vals); + __m128i high = _mm256_extracti128_si256(vals, 1); + __m128i sum = _mm_add_epi32(low, high); + sum = _mm_hadd_epi32(sum, sum); + sum = _mm_hadd_epi32(sum, sum); + acc += _mm_cvtsi128_si32(sum); + } +#endif + // 處理剩餘元素 + for (size_t kk = k + ((k_end - k) & ~7); kk < k_end; ++kk) { + acc += lut_ptr[rowA[kk] * stride + Bu[kk * N + jj]]; + } + + local_C.set(ii, jj, local_C.at(ii, jj) + acc); + } + } + } + } + } + + // 合併結果 + std::lock_guard lock(mtx); + for (size_t i = start_row; i < end_row; ++i) { + for (size_t j = 0; j < N; ++j) { + C.set(i, j, C.at(i, j) + local_C.at(i, j)); + } + } + }); + } + + // 等待所有執行緒完成 + for (auto& thread : threads) { + thread.join(); + } + + return C; +} \ No newline at end of file From f4664fcede934b297d3c5294419358c06bde4430 Mon Sep 17 00:00:00 2001 From: 5000user5000 Date: Wed, 7 May 2025 22:32:32 +0800 Subject: [PATCH 3/3] remove redundant variable references --- Makefile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index db756b4..1794317 100644 --- a/Makefile +++ b/Makefile @@ -48,19 +48,19 @@ $(BUILD_DIR): mkdir -p $(BUILD_DIR) # build main -$(TARGET_MAIN): $(BUILD_DIR) $(TEST_SRC) $(HEADERS) +$(TARGET_MAIN): $(TEST_SRC) $(HEADERS) $(CXX) $(CXXFLAGS) $(TEST_SRC) -o $(TARGET_MAIN) $(LDFLAGS) $(LDLIBS) # build correctness suite -$(TARGET_CORR): $(BUILD_DIR) $(CORR_SRC) $(HEADERS) +$(TARGET_CORR): $(CORR_SRC) $(HEADERS) $(CXX) $(CXXFLAGS) $(CORR_SRC) -o $(TARGET_CORR) $(LDFLAGS) $(LDLIBS) # build matrix ops test -$(TARGET_MATRIX_OPS): $(BUILD_DIR) $(MATRIX_OPS_SRC) $(HEADERS) +$(TARGET_MATRIX_OPS): $(MATRIX_OPS_SRC) $(HEADERS) $(CXX) $(CXXFLAGS) -pthread $(MATRIX_OPS_SRC) -o $(TARGET_MATRIX_OPS) $(LDFLAGS) $(LDLIBS) # build pybind11 module -mpgemm$(PYEXT): $(BUILD_DIR) src/bindings.cpp $(HEADERS) +mpgemm$(PYEXT): src/bindings.cpp $(HEADERS) $(CXX) $(CXXFLAGS) $(PYBIND11_INC) -fPIC -shared src/bindings.cpp -o $@ $(LDFLAGS) $(LDLIBS) # run pytest