diff --git a/Makefile b/Makefile index 444f41c..45464b2 100644 --- a/Makefile +++ b/Makefile @@ -83,7 +83,6 @@ matrix_ops: $(TARGET_MATRIX_OPS) @echo "\nRunning LUT version..." @./$(TARGET_MATRIX_OPS) lut -# 方便的命令 matrix_ops_float: $(TARGET_MATRIX_OPS) ./$(TARGET_MATRIX_OPS) float diff --git a/scripts/example.py b/scripts/example.py index 443c76d..09a2937 100644 --- a/scripts/example.py +++ b/scripts/example.py @@ -2,7 +2,6 @@ import os import sys -# 确保项目根目录在 PYTHONPATH sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, '..'))) import numpy as np @@ -10,49 +9,46 @@ from mpgemm import Activation def main(): - # 矩阵维度 + M, K, N = 12, 12, 12 - # 随机数生成器 rng = np.random.default_rng(2025) - # 生成随机 INT4 权重(-8 到 +7) + # generate random int4 weights in range [-8, 7] weights = rng.integers(-8, 8, size=(M, K), dtype=np.int8) - # 將有符號 int4 轉換為無符號表示 weights_unsigned = np.where(weights < 0, weights + 16, weights).astype(np.uint8) - # 生成随机 FP16 激活(使用標準正態分佈) + # generate random activations activations = rng.standard_normal(size=(K, N)).astype(np.float16) - # 随机 bias(FP32,範圍也限制在合理範圍內) + bias = rng.uniform(-1, 1, size=N).astype(np.float32) - # 扁平化并转为 Python 列表 w_flat = weights_unsigned.flatten().tolist() a_flat = activations.flatten().astype(float).tolist() bias_list = bias.tolist() - # === 1. 基准参考输出 === + # === 1. Baseline: Naive GEMM === gemm_ref = mpgemm.Engine("naive") ref_flat = gemm_ref.matmul(w_flat, a_flat, M, K, N) - # === 2. LUT 后端输出 === + # === 2. LUT GEMM === gemm_lut = mpgemm.Engine("lut") gemm_lut.generate_lut(bit_width=4) out_flat = gemm_lut.matmul(w_flat, a_flat, M, K, N) - # === 3. 后处理示例 === + # === 3. Post-processing === out_biased = gemm_lut.add_bias(out_flat, M, N, bias_list) out_relu = gemm_lut.apply_activation(out_biased, M, N, Activation.ReLU) - # 还原成矩阵 + # === 4. Output === output = np.array(out_relu, dtype=np.float32).reshape(M, N) print(f"Output shape: {output.shape}") print("Sample row[0, :5]:", output[0, :5]) - # === 4. 
误差分析示例 === + # === 5. Error measurement === stats = mpgemm.measure_error(ref_flat, out_flat) print(f"\nError relative to naive:") print(f" MSE = {stats['mse']:.6f}") diff --git a/src/gemm_engine.hpp b/src/gemm_engine.hpp index c9b3db6..2501217 100644 --- a/src/gemm_engine.hpp +++ b/src/gemm_engine.hpp @@ -21,7 +21,6 @@ enum class Backend { class Engine { public: - // 构造:支持 "naive"/"lut" (和 "mkl") Engine(const std::string &backend_str) : lut(nullptr) { @@ -33,7 +32,6 @@ class Engine { else throw std::invalid_argument("Unknown backend: " + backend_str); } - // 仅在 LUT 模式下调用,bit_width 一般传 4 void generate_lut(int bit_width) { if (backend != Backend::LUT) throw std::runtime_error("generate_lut only valid for LUT backend"); @@ -43,7 +41,6 @@ class Engine { lut = std::make_unique>(range, range); } - // matmul: Wflat 是 uint8_t(int4 存 raw),Aflat 是 float(list) std::vector matmul( const std::vector& Wflat, const std::vector& Aflat, @@ -55,16 +52,14 @@ class Engine { switch (backend) { case Backend::Naive: { Matrix> Wi(M, K), Ai(K, N); - // 填充 Wi, Ai,將 uint8_t 轉換為有符號 int for (int i = 0; i < M; ++i) for (int j = 0; j < K; ++j) { int val = Wflat[size_t(i)*K + j]; - Wi.set(i, j, val < 8 ? val : val - 16); // 轉換為有符號 + Wi.set(i, j, val < 8 ? 
val : val - 16); } for (int i = 0; i < K; ++i) for (int j = 0; j < N; ++j) Ai.set(i, j, (int)std::lround(Aflat[size_t(i)*N + j])); - // 调用全局 ::matmul auto C = ::matmul(Wi, Ai); for (int i = 0; i < M; ++i) for (int j = 0; j < N; ++j) @@ -73,22 +68,19 @@ class Engine { } case Backend::LUT: { if (!lut) throw std::runtime_error("LUT not generated"); - // 从 raw uint8 构造 Int4Storage 矩阵并 unpack + Matrix Wq(M,K); for (int i = 0; i < M; ++i) for (int j = 0; j < K; ++j) Wq.set(i,j, Wflat[size_t(i)*K + j]); auto Wu = unpack_int4(Wq); - // 把 Activation floats 截断、cast 为 uint8 索引 std::vector Au(size_t(K)*N); for (int i = 0; i < K; ++i) for (int j = 0; j < N; ++j) { float val = Aflat[size_t(i)*N + j]; - // 將浮點數轉換為有符號 int4 範圍 int q = std::lround(val); - q = std::clamp(q, -8, 7); // 限制在有符號 int4 範圍 - // 轉換為無符號表示 + q = std::clamp(q, -8, 7); Au[size_t(i)*N + j] = uint8_t(q < 0 ? q + 16 : q); } diff --git a/src/lut_utils.hpp b/src/lut_utils.hpp index c367f35..067ad6e 100644 --- a/src/lut_utils.hpp +++ b/src/lut_utils.hpp @@ -1,15 +1,19 @@ #pragma once + #include #include #include -#include // for posix_memalign, free +#include // posix_memalign, free #include #include -#include +#include -// 對齊分配器:以 Align 對齊分配 +// Aligned allocator with compile-time check template struct AlignedAllocator { + static_assert((Align & (Align - 1)) == 0 && Align >= alignof(T), + "Align must be power-of-two and at least alignof(T)"); + using value_type = T; using pointer = T*; using const_pointer = const T*; @@ -40,8 +44,8 @@ struct AlignedAllocator { } }; -// lookup table for product of two integers -// LUT:64-byte 對齊,並將每列 padding 到 8 的倍數 +// Lookup table for product of two integers, 64-byte aligned +// W: weight type, A: activation type, P: product type template > class ProductLookupTable { public: @@ -49,75 +53,41 @@ class ProductLookupTable { using ActivationType = A; using ProductType = P; - // 建構内部初始化,接受權重量化層數與 activation 長度 ProductLookupTable(std::size_t weight_levels, std::size_t a_range) : 
weight_levels_(weight_levels), a_range_(a_range), padded_a_range_(((a_range + 7) / 8) * 8), table_(weight_levels * padded_a_range_) { - // Default initialization: scalar LUT for weight × activation (raw) values - for (std::size_t w = 0; w < weight_levels_; ++w) { - int signed_w = static_cast(w < (weight_levels_/2) ? w : w - weight_levels_); - P* row_ptr = &table_[w * padded_a_range_]; - for (std::size_t a = 0; a < a_range_; ++a) { - // Activation index a treated as raw ActivationType - ActivationType act_val = static_cast(a); - int64_t prod = static_cast(signed_w) * static_cast(act_val); - // Saturate to ProductType range - if (prod > std::numeric_limits
<P>
::max()) prod = std::numeric_limits
<P>
::max(); - else if (prod < std::numeric_limits
<P>
::min()) prod = std::numeric_limits
<P>
::min(); - row_ptr[a] = static_cast
<P>
(prod); + // Default: build LUT with raw indices as activations + fill_impl([&](std::size_t a) -> int64_t { + int64_t raw = static_cast(a); + // two's-complement for 4-bit (uint8_t) activations + int64_t act = raw; + if constexpr (std::is_same::value) { + act = (raw < static_cast(weight_levels_ / 2)) + ? raw : (raw - static_cast(weight_levels_)); } - // padding region left uninitialized - } + return act; + }); } - /** - * fill_from_activation: - * 根據輸入的 activation row (act_row, 長度 = a_range_), - * 為所有 weight level 預先計算並填充查表數據: - * 對於每個 weight w (0..weight_levels_-1), - * 1. 轉為有符號值 signed_w - * 2. 遍歷所有 activation 值 act_row[a] - * 3. 計算 signed_w * act_row[a],並做飽和處理 - * 4. 存入 table_[w * padded_a_range_ + a] - * - * 這樣,後續 matmul 可直接對每個 weight lookup 整列乘積向量, - * 而不需即時計算乘法。 - */ - void fill_from_activation(const ActivationType* act_row) { - for (std::size_t w = 0; w < weight_levels_; ++w) { - int signed_w = static_cast(w < (weight_levels_/2) ? w : w - weight_levels_); - P* row_ptr = &table_[w * padded_a_range_]; - for (std::size_t a = 0; a < a_range_; ++a) { - // 讀取 activation 值;若是 uint8_t,則認為是 quantized 4-bit 需做 signed 轉換 - P act_val; - if constexpr (std::is_same::value) { - uint8_t raw = act_row[a]; - // two's complement mapping for 4-bit - act_val = static_cast
<P>
( raw < (weight_levels_/2) ? raw : raw - weight_levels_ ); - } else { - act_val = static_cast
<P>
(act_row[a]); - } - // 計算乘積 - - int64_t prod_int = static_cast(signed_w) * static_cast(act_val); - // 飽和處理 - if (prod_int > std::numeric_limits
<P>
::max()) prod_int = std::numeric_limits
<P>
::max(); - else if (prod_int < std::numeric_limits
<P>
::min()) prod_int = std::numeric_limits
<P>
::min(); - row_ptr[a] = static_cast
<P>
(prod_int); + void fill_from_activation(const ActivationType* act_row) noexcept { + // Refill LUT with actual activation values + fill_impl([&](std::size_t a) -> int64_t { + int64_t raw = static_cast(act_row[a]); + if constexpr (std::is_same::value) { + // two's-complement mapping for 4-bit + raw = (raw < static_cast(weight_levels_ / 2)) + ? raw : (raw - static_cast(weight_levels_)); } - // padding 部分不必初始化 - } + return raw; + }); } - // 取得 weight-level w 的向量查表指標 inline const P* get_row(std::size_t w) const noexcept { return &table_[w * padded_a_range_]; } - - // 元素查表,保留舊接口 inline ProductType get(std::size_t w, std::size_t a) const noexcept { return table_[w * padded_a_range_ + a]; } @@ -125,7 +95,6 @@ class ProductLookupTable { return get(w, a); } - // table 資料屬性 const P* data() const noexcept { return table_.data(); } std::size_t row_stride() const noexcept { return padded_a_range_; } std::size_t weight_levels() const noexcept { return weight_levels_; } @@ -134,5 +103,29 @@ class ProductLookupTable { private: std::size_t weight_levels_, a_range_, padded_a_range_; - std::vector> table_; // flat buffer (w * padded + a) + std::vector> table_; + + // Multiply and saturate in product type range + static P compute_prod(int64_t signed_w, int64_t act_val) noexcept { + int64_t prod = signed_w * act_val; + if (prod > std::numeric_limits
<P>
::max()) prod = std::numeric_limits
<P>
::max(); + else if (prod < std::numeric_limits
<P>
::min()) prod = std::numeric_limits
<P>
::min(); + return static_cast
<P>
(prod); + } + + // Core fill logic: computes table entries by combining signed weight and activation + template + void fill_impl(GetAct get_act) noexcept { + for (std::size_t w = 0; w < weight_levels_; ++w) { + int64_t signed_w = (w < weight_levels_ / 2) + ? static_cast(w) + : static_cast(w) - static_cast(weight_levels_); + P* row_ptr = &table_[w * padded_a_range_]; + for (std::size_t a = 0; a < a_range_; ++a) { + int64_t act = get_act(a); + row_ptr[a] = compute_prod(signed_w, act); + } + // padding intentionally left uninitialized for performance + } + } }; diff --git a/src/matrix_ops.hpp b/src/matrix_ops.hpp index a625726..664a324 100644 --- a/src/matrix_ops.hpp +++ b/src/matrix_ops.hpp @@ -46,16 +46,13 @@ auto matmul(const MA& A, const MB& B, size_t num_threads = 4) Matrix> C(M, N); std::vector threads; - // 計算每個執行緒處理的行數 size_t rows_per_thread = (M + num_threads - 1) / num_threads; - // 為每個執行緒分配工作 for (size_t t = 0; t < num_threads; ++t) { threads.emplace_back([&, t]() { size_t start_row = t * rows_per_thread; size_t end_row = std::min(start_row + rows_per_thread, M); - // 為每個執行緒創建局部結果矩陣 Matrix> local_C(M, N); for (size_t i = start_row; i < end_row; ++i) { @@ -67,7 +64,6 @@ auto matmul(const MA& A, const MB& B, size_t num_threads = 4) } } - // 合併結果 static std::mutex mtx; std::lock_guard lock(mtx); for (size_t i = start_row; i < end_row; ++i) { @@ -78,7 +74,6 @@ auto matmul(const MA& A, const MB& B, size_t num_threads = 4) }); } - // 等待所有執行緒完成 for (auto& thread : threads) { thread.join(); } diff --git a/src/post_processing.hpp b/src/post_processing.hpp index aec487e..eaa0280 100644 --- a/src/post_processing.hpp +++ b/src/post_processing.hpp @@ -3,7 +3,7 @@ #include #include "matrix.hpp" -/// 後處理可選激活函式 + enum class Activation { Linear, ReLU, @@ -11,7 +11,6 @@ enum class Activation { Tanh }; -/// 1) bias 加法:對於每一列,將 bias[j] 加到 M(i,j) template Matrix add_bias( const Matrix& M, diff --git a/src/quant_utils.hpp b/src/quant_utils.hpp index 1c09e9f..67245a6 100644 --- 
a/src/quant_utils.hpp +++ b/src/quant_utils.hpp @@ -3,26 +3,24 @@ #include #include -// INT4 量化:fp16_val → uint8_t (低 4 bits) inline uint8_t quantize_int4(float fp16_val, float scale, int zero_point = 8) { - // 先將值限制在有符號 int4 範圍內(考慮 scale) + float min_val = -8.0f * scale; float max_val = 7.0f * scale; float clamped_val = std::clamp(fp16_val, min_val, max_val); - // 將值轉換為整數,並加上 zero_point + float scaled_val = clamped_val / scale; int q = static_cast(std::round(scaled_val)) + zero_point; - // 確保結果在 0-15 範圍內 + q = std::clamp(q, 0, 15); return static_cast(q); } -// INT4 反量化:uint8_t (低 4 bits) → float + inline float dequantize_int4(uint8_t q, float scale, int zero_point = 8) { - // 將無符號值轉換為有符號值 + int qi = static_cast(q) - zero_point; - // 轉換回浮點數 return static_cast(qi) * scale; } diff --git a/tests/run_benchmark.cpp b/tests/run_benchmark.cpp index feadd9d..46b2f7c 100644 --- a/tests/run_benchmark.cpp +++ b/tests/run_benchmark.cpp @@ -11,18 +11,17 @@ #include int main(int argc, char** argv) { - // 默认参数 + // default prameters int M = 500, K = 600, N = 500; bool run_naive_int = true; bool run_naive_float = true; bool run_lut = true; #ifdef USE_MKL - bool run_mkl = true; // 默认为 true + bool run_mkl = true; #else bool run_mkl = false; #endif - // 解析命令行 for (int i = 1; i < argc; ++i) { if (strcmp(argv[i], "--m")==0) M = std::atoi(argv[++i]); else if (strcmp(argv[i], "--k")==0) K = std::atoi(argv[++i]); @@ -34,17 +33,16 @@ int main(int argc, char** argv) { std::cout << "[Shape] M=" << M << ", K=" << K << ", N=" << N << "\n\n"; - // 随机数生成 + // random number generator std::mt19937 rng(12345); std::uniform_int_distribution dist_int(0, 100); - // 准备基准数据(int) + // generate random matrices Matrix> A_i(M, K), B_i(K, N); for (int i=0; i; using Int4C = Matrix; Int4R A4(M, K); @@ -56,7 +54,6 @@ int main(int argc, char** argv) { for (int j = 0; j < N; ++j) B4.set(k, j, static_cast(B_i.at(k, j) & 0x0F)); - // 再 unpack auto Au = unpack_int4(A4); auto Bu = unpack_int4(B4); ProductLookupTable 
lut(16,16); diff --git a/tests/test_correctness.cpp b/tests/test_correctness.cpp index 21e18c8..a920bde 100644 --- a/tests/test_correctness.cpp +++ b/tests/test_correctness.cpp @@ -226,18 +226,17 @@ bool run_int4_fast_test(){ using Mat4C = Matrix; std::mt19937 rng(42); - // 修改分佈範圍為 -8 到 7 std::uniform_int_distribution d4(-8,7); Mat4R A4(M,K); Mat4C B4(K,N); for(int i=0;i> Au_mat(M,K), Bu_mat(K,N); for(int i=0;i> M(2,3); M.set(0,0,1); M.set(0,1,2); M.set(0,2,3); M.set(1,0,4); M.set(1,1,5); M.set(1,2,6); @@ -350,7 +349,6 @@ bool run_sigmoid_test() { M.set(0,2,-2.0f); auto R = apply_activation(M, Activation::Sigmoid); - // 理論值 float s0 = 1.0f/(1+std::exp(-0.0f)); // 0.5 float s1 = 1.0f/(1+std::exp(-2.0f)); float s2 = 1.0f/(1+std::exp( 2.0f)); @@ -435,7 +433,7 @@ int main() { if (run_linear_test()) ++passed; if (run_accuracy_test()) ++passed; #ifdef USE_MKL - ++total; // 只有啟用 MKL 才加總數 + ++total; // MKL test is optional if (run_mkl_test()) ++passed; #endif std::cout << "\nTotal: " << passed << "/" << total << " tests passed.\n"; diff --git a/tests/test_matrix_ops.cpp b/tests/test_matrix_ops.cpp index bdef844..8ffb46c 100644 --- a/tests/test_matrix_ops.cpp +++ b/tests/test_matrix_ops.cpp @@ -9,11 +9,9 @@ void run_float() { const size_t K = 1024; const size_t N = 1024; - // 建立 float 矩陣 Matrix> A_float(M, K); Matrix> B_float(K, N); - // 初始化矩陣 for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < K; ++j) { A_float.set(i, j, static_cast(i + j) / 1000.0f); @@ -38,11 +36,9 @@ void run_lut() { const size_t K = 1024; const size_t N = 1024; - // 建立 int4 矩陣 Matrix A_int4(M, K); Matrix B_int4(K, N); - // 初始化矩陣 for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < K; ++j) { A_int4.set(i, j, (i + j) % 16); @@ -55,7 +51,6 @@ void run_lut() { } } - // 建立查找表 ProductLookupTable lut(16, 16); auto start = std::chrono::high_resolution_clock::now();