From 3a661fbab812833d4b1f40fc81a6ecaaee9f6673 Mon Sep 17 00:00:00 2001
From: 5000user5000
Date: Fri, 9 May 2025 13:24:25 +0800
Subject: [PATCH 1/4] int4 quant support

---
 scripts/example.py         |  9 ++++++---
 src/gemm_engine.hpp        | 18 ++++++++++++------
 src/quant_utils.hpp        | 19 +++++++++++++++----
 tests/test_correctness.cpp | 11 ++++++++---
 4 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/scripts/example.py b/scripts/example.py
index 61f70e2..68704eb 100644
--- a/scripts/example.py
+++ b/scripts/example.py
@@ -16,15 +16,18 @@ def main():
     # Random number generator
     rng = np.random.default_rng(2025)

-    # Generate random INT4 weights (0-15)
-    weights = rng.integers(0, 16, size=(M, K), dtype=np.uint8)
+    # Generate random INT4 weights (-8 to +7)
+    weights = rng.integers(-8, 8, size=(M, K), dtype=np.int8)
+    # Convert the signed int4 values to their unsigned representation
+    weights_unsigned = np.where(weights < 0, weights + 16, weights).astype(np.uint8)
+
     # Generate random FP16 activations
     activations = rng.standard_normal(size=(K, N)).astype(np.float16)
     # Random bias (FP32)
     bias = rng.standard_normal(size=N).astype(np.float32)

     # Flatten and convert to Python lists
-    w_flat = weights.flatten().tolist()
+    w_flat = weights_unsigned.flatten().tolist()
     a_flat = activations.flatten().astype(float).tolist()
     bias_list = bias.tolist()

diff --git a/src/gemm_engine.hpp b/src/gemm_engine.hpp
index 93dbef2..86bd962 100644
--- a/src/gemm_engine.hpp
+++ b/src/gemm_engine.hpp
@@ -55,10 +55,12 @@ class Engine {
         switch (backend) {
           case Backend::Naive: {
             Matrix<int32_t> Wi(M, K), Ai(K, N);
-            // Fill Wi and Ai
+            // Fill Wi and Ai, converting uint8_t to signed int
             for (int i = 0; i < M; ++i)
-              for (int j = 0; j < K; ++j)
-                Wi.set(i, j, (int)Wflat[size_t(i)*K + j]);
+              for (int j = 0; j < K; ++j) {
+                int val = Wflat[size_t(i)*K + j];
+                Wi.set(i, j, val < 8 ? val : val - 16);  // map back to signed
+              }
             for (int i = 0; i < K; ++i)
               for (int j = 0; j < N; ++j)
                 Ai.set(i, j, (int)std::lround(Aflat[size_t(i)*N + j]));
@@ -77,11 +79,15 @@ class Engine {
               for (int j = 0; j < K; ++j)
                 Wq.set(i,j, Wflat[size_t(i)*K + j]);
             auto Wu = unpack_int4(Wq);
-            // Truncate the activation floats and cast them to uint8 indices
+            // Truncate the activation floats, cast to uint8 indices, and map to signed
             std::vector<uint8_t> Au(size_t(K)*N);
             for (int i = 0; i < K; ++i)
-              for (int j = 0; j < N; ++j)
-                Au[size_t(i)*N + j] = uint8_t(std::lround(Aflat[size_t(i)*N + j]));
+              for (int j = 0; j < N; ++j) {
+                float val = Aflat[size_t(i)*N + j];
+                int q = std::lround(val);
+                q = std::clamp(q, -8, 7);  // restrict to the signed int4 range
+                Au[size_t(i)*N + j] = uint8_t(q < 0 ? q + 16 : q);
+              }
             auto Ci = matmul_lut_fast(Wu, Au, M, K, N, *lut);
             for (int i = 0; i < M; ++i)
               for (int j = 0; j < N; ++j)

diff --git a/src/quant_utils.hpp b/src/quant_utils.hpp
index 55136fe..1c09e9f 100644
--- a/src/quant_utils.hpp
+++ b/src/quant_utils.hpp
@@ -4,14 +4,25 @@
 #include <algorithm>

 // INT4 quantization: fp16_val -> uint8_t (low 4 bits)
-inline uint8_t quantize_int4(float fp16_val, float scale, int zero_point = 0) {
-    int q = static_cast<int>(std::round(fp16_val / scale)) + zero_point;
-    q = std::clamp(q, 0, 15);  // unsigned INT4, 0..15
+inline uint8_t quantize_int4(float fp16_val, float scale, int zero_point = 8) {
+    // First clamp the value to the signed int4 range (taking the scale into account)
+    float min_val = -8.0f * scale;
+    float max_val = 7.0f * scale;
+    float clamped_val = std::clamp(fp16_val, min_val, max_val);
+
+    // Convert to an integer and add the zero point
+    float scaled_val = clamped_val / scale;
+    int q = static_cast<int>(std::round(scaled_val)) + zero_point;
+
+    // Make sure the result stays within 0-15
+    q = std::clamp(q, 0, 15);
     return static_cast<uint8_t>(q);
 }

 // INT4 dequantization: uint8_t (low 4 bits) -> float
-inline float dequantize_int4(uint8_t q, float scale, int zero_point = 0) {
+inline float dequantize_int4(uint8_t q, float scale, int zero_point = 8) {
+    // Map the unsigned value back to signed
     int qi = static_cast<int>(q) - zero_point;
+    // Convert back to floating point
     return static_cast<float>(qi) * scale;
 }

diff --git a/tests/test_correctness.cpp b/tests/test_correctness.cpp
index 233f745..7719a3d 100644
--- a/tests/test_correctness.cpp
+++ b/tests/test_correctness.cpp
@@ -253,12 +253,17 @@ bool run_int4_fast_test(){
 // 7. Quantization/Dequantization test
 bool run_quant_dequant_test() {
     std::cout << "Running INT4 quant-dequant test...\n";
-    float scale = 0.25f;  // arbitrary choice
+    float scale = 1.0f;  // arbitrary choice
     bool pass = true;
-    for (float v : {0.0f, 1.0f, 2.25f, 3.5f}) {
+    // Test the signed int4 range (-8 to +7)
+    for (float v : {-8.0f, -4.0f, 0.0f, 4.0f, 7.0f}) {
         uint8_t q = quantize_int4(v, scale);
         float d = dequantize_int4(q, scale);
-        if (std::abs(d - std::round(v/scale)*scale) > 1e-3f) pass = false;
+        if (std::abs(d - v) > 1e-3f) {
+            std::cout << "Failed at value " << v << ": quantized=" << (int)q
+                      << ", dequantized=" << d << "\n";
+            pass = false;
+        }
     }
     std::cout << (pass ? "Quant-Dequant test PASS\n" : "FAIL\n");
     return pass;
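
A note on the scheme above: zero_point = 8 maps the signed int4 range [-8, 7] onto the stored unsigned range [0, 15], so every representable level survives a quantize/dequantize round trip exactly. A minimal standalone sketch of that round trip (the two functions mirror the quant_utils.hpp diff above; the main() driver and its sample scale are illustrative only):

    #include <algorithm>
    #include <cassert>
    #include <cmath>
    #include <cstdint>

    // Same scheme as quant_utils.hpp above: signed [-8, 7] stored as
    // unsigned [0, 15] with zero_point = 8.
    inline uint8_t quantize_int4(float v, float scale, int zero_point = 8) {
        float c = std::clamp(v, -8.0f * scale, 7.0f * scale);
        int q = static_cast<int>(std::round(c / scale)) + zero_point;
        return static_cast<uint8_t>(std::clamp(q, 0, 15));
    }

    inline float dequantize_int4(uint8_t q, float scale, int zero_point = 8) {
        return static_cast<float>(static_cast<int>(q) - zero_point) * scale;
    }

    int main() {
        const float scale = 0.5f;            // illustrative scale
        for (int s = -8; s <= 7; ++s) {      // every representable int4 level
            float v = s * scale;
            assert(dequantize_int4(quantize_int4(v, scale), scale) == v);
        }
        return 0;
    }
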
From e3bdfaa6f18c3f5b64b3e089b8f8046e053012b0 Mon Sep 17 00:00:00 2001
From: 5000user5000
Date: Tue, 13 May 2025 15:24:32 +0800
Subject: [PATCH 2/4] fix proposal img link

---
 doc/proposal.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/proposal.md b/doc/proposal.md
index d0a7cfe..0318005 100644
--- a/doc/proposal.md
+++ b/doc/proposal.md
@@ -11,7 +11,7 @@ common precision, as hardware lacks native support for mixed-precision
 matrix multiplication (mpGEMM). Some recent research suggests using lookup
 tables (LUTs) to replace dequantization, further reducing computational
 overhead.

-![LUT](./img/lut.png)
+![LUT](../img/lut.png)

 ## Problem to Solve
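
The LUT idea cited in the proposal, in miniature: two 4-bit operands admit only 16 x 16 distinct products, so all of them can be tabulated once and the GEMM inner loop becomes an index into that table instead of a dequantize-and-multiply. A self-contained sketch (illustrative only; build_int4_product_lut is not a name from this repo, though the unsigned-nibble encoding matches the patches):

    #include <array>
    #include <cstdint>
    #include <cstdio>

    // Tabulate every product of two signed int4 values, where each operand
    // is stored as an unsigned nibble (0..15, two's-complement encoding).
    std::array<int32_t, 16 * 16> build_int4_product_lut() {
        std::array<int32_t, 16 * 16> lut{};
        for (int w = 0; w < 16; ++w)
            for (int a = 0; a < 16; ++a) {
                int sw = w < 8 ? w : w - 16;  // unsigned nibble -> signed value
                int sa = a < 8 ? a : a - 16;
                lut[w * 16 + a] = sw * sa;
            }
        return lut;
    }

    int main() {
        auto lut = build_int4_product_lut();
        // (-3) * 5: -3 is stored as the nibble 13
        std::printf("%d\n", lut[13 * 16 + 5]);  // prints -15
        return 0;
    }
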
-![LUT](./img/lut.png) +![LUT](../img/lut.png) ## Problem to Solve From ef5ccb0f8458449b99b33cdd11fc0d130b242141 Mon Sep 17 00:00:00 2001 From: 5000user5000 Date: Wed, 14 May 2025 15:19:54 +0800 Subject: [PATCH 3/4] feat:int4 support --- Makefile | 6 +- scripts/example.py | 17 ++++-- src/gemm_engine.hpp | 10 +++- src/lut_utils.hpp | 105 +++++++++++++++++++++++++++-------- src/matrix_ops.hpp | 111 ++++++++++--------------------------- tests/test_correctness.cpp | 27 +++++++-- 6 files changed, 157 insertions(+), 119 deletions(-) diff --git a/Makefile b/Makefile index 1794317..444f41c 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ CXX := g++ -CXXFLAGS := -std=c++17 -O2 -Wall -I./src -march=native +CXXFLAGS := -std=c++17 -O2 -Wall -I./src -march=native -fopenmp PYBIND11_INC := $(shell python3 -m pybind11 --includes) PYEXT := $(shell python3-config --extension-suffix) @@ -14,6 +14,10 @@ ifeq ($(USE_MKL),1) endif # -------------------- +# ---- OpenMP toggle ---- +LDLIBS += -lgomp +# -------------------- + SRC_DIR := src TEST_DIR := tests diff --git a/scripts/example.py b/scripts/example.py index 68704eb..ea409a4 100644 --- a/scripts/example.py +++ b/scripts/example.py @@ -11,20 +11,23 @@ def main(): # 矩阵维度 - M, K, N = 128, 128, 128 + M, K, N = 12, 12, 12 # 随机数生成器 rng = np.random.default_rng(2025) # 生成随机 INT4 权重(-8 到 +7) weights = rng.integers(-8, 8, size=(M, K), dtype=np.int8) + print("weights:") + print(weights) # 將有符號 int4 轉換為無符號表示 weights_unsigned = np.where(weights < 0, weights + 16, weights).astype(np.uint8) - - # 生成随机 FP16 激活 + print("weights_unsigned:") + print(weights_unsigned) + # 生成随机 FP16 激活(使用標準正態分佈) activations = rng.standard_normal(size=(K, N)).astype(np.float16) - # 随机 bias(FP32) - bias = rng.standard_normal(size=N).astype(np.float32) + # 随机 bias(FP32,範圍也限制在合理範圍內) + bias = rng.uniform(-1, 1, size=N).astype(np.float32) # 扁平化并转为 Python 列表 w_flat = weights_unsigned.flatten().tolist() @@ -34,11 +37,15 @@ def main(): # === 1. 基准参考输出 === gemm_ref = mpgemm.Engine("naive") ref_flat = gemm_ref.matmul(w_flat, a_flat, M, K, N) + print("ref_flat:") + print(ref_flat) # === 2. LUT 后端输出 === gemm_lut = mpgemm.Engine("lut") gemm_lut.generate_lut(bit_width=4) out_flat = gemm_lut.matmul(w_flat, a_flat, M, K, N) + print("out_flat:") + print(out_flat) # === 3. 后处理示例 === out_biased = gemm_lut.add_bias(out_flat, M, N, bias_list) diff --git a/src/gemm_engine.hpp b/src/gemm_engine.hpp index 86bd962..c9b3db6 100644 --- a/src/gemm_engine.hpp +++ b/src/gemm_engine.hpp @@ -26,9 +26,9 @@ class Engine { : lut(nullptr) { if (backend_str == "naive") backend = Backend::Naive; - else if (backend_str == "lut") backend = Backend::LUT; + else if (backend_str == "lut") backend = Backend::LUT; #ifdef USE_MKL - else if (backend_str == "mkl") backend = Backend::MKL; + else if (backend_str == "mkl") backend = Backend::MKL; #endif else throw std::invalid_argument("Unknown backend: " + backend_str); } @@ -79,15 +79,19 @@ class Engine { for (int j = 0; j < K; ++j) Wq.set(i,j, Wflat[size_t(i)*K + j]); auto Wu = unpack_int4(Wq); - // 把 Activation floats 截断、cast 为 uint8 索引,並轉換為有符號 + + // 把 Activation floats 截断、cast 为 uint8 索引 std::vector Au(size_t(K)*N); for (int i = 0; i < K; ++i) for (int j = 0; j < N; ++j) { float val = Aflat[size_t(i)*N + j]; + // 將浮點數轉換為有符號 int4 範圍 int q = std::lround(val); q = std::clamp(q, -8, 7); // 限制在有符號 int4 範圍 + // 轉換為無符號表示 Au[size_t(i)*N + j] = uint8_t(q < 0 ? 
q + 16 : q); } + auto Ci = matmul_lut_fast(Wu, Au, M, K, N, *lut); for (int i = 0; i < M; ++i) for (int j = 0; j < N; ++j) diff --git a/src/lut_utils.hpp b/src/lut_utils.hpp index 2623a67..c367f35 100644 --- a/src/lut_utils.hpp +++ b/src/lut_utils.hpp @@ -4,6 +4,8 @@ #include #include // for posix_memalign, free #include +#include +#include // 對齊分配器:以 Align 對齊分配 template @@ -40,42 +42,97 @@ struct AlignedAllocator { // lookup table for product of two integers // LUT:64-byte 對齊,並將每列 padding 到 8 的倍數 -template > +template > class ProductLookupTable { public: using WeightType = W; using ActivationType = A; using ProductType = P; - ProductLookupTable(std::size_t w_range, std::size_t a_range) - : w_range_(w_range), - a_range_(a_range), - padded_a_range_(((a_range + 7) / 8) * 8), - table_(w_range * padded_a_range_) // 用 padding 後的尺寸 + // 建構内部初始化,接受權重量化層數與 activation 長度 + ProductLookupTable(std::size_t weight_levels, std::size_t a_range) + : weight_levels_(weight_levels), + a_range_(a_range), + padded_a_range_(((a_range + 7) / 8) * 8), + table_(weight_levels * padded_a_range_) { - for (std::size_t w = 0; w < w_range_; ++w) - for (std::size_t a = 0; a < a_range_; ++a) - table_[w * padded_a_range_ + a] = - static_cast(w) * static_cast(a); + // Default initialization: scalar LUT for weight × activation (raw) values + for (std::size_t w = 0; w < weight_levels_; ++w) { + int signed_w = static_cast(w < (weight_levels_/2) ? w : w - weight_levels_); + P* row_ptr = &table_[w * padded_a_range_]; + for (std::size_t a = 0; a < a_range_; ++a) { + // Activation index a treated as raw ActivationType + ActivationType act_val = static_cast(a); + int64_t prod = static_cast(signed_w) * static_cast(act_val); + // Saturate to ProductType range + if (prod > std::numeric_limits

::max()) prod = std::numeric_limits

::max(); + else if (prod < std::numeric_limits

::min()) prod = std::numeric_limits

::min(); + row_ptr[a] = static_cast

(prod); + } + // padding region left uninitialized + } } - // ---- scalar accessors ---- - inline ProductType get(W w, A a) const noexcept { - return table_[size_t(w) * padded_a_range_ + size_t(a)]; + /** + * fill_from_activation: + * 根據輸入的 activation row (act_row, 長度 = a_range_), + * 為所有 weight level 預先計算並填充查表數據: + * 對於每個 weight w (0..weight_levels_-1), + * 1. 轉為有符號值 signed_w + * 2. 遍歷所有 activation 值 act_row[a] + * 3. 計算 signed_w * act_row[a],並做飽和處理 + * 4. 存入 table_[w * padded_a_range_ + a] + * + * 這樣,後續 matmul 可直接對每個 weight lookup 整列乘積向量, + * 而不需即時計算乘法。 + */ + void fill_from_activation(const ActivationType* act_row) { + for (std::size_t w = 0; w < weight_levels_; ++w) { + int signed_w = static_cast(w < (weight_levels_/2) ? w : w - weight_levels_); + P* row_ptr = &table_[w * padded_a_range_]; + for (std::size_t a = 0; a < a_range_; ++a) { + // 讀取 activation 值;若是 uint8_t,則認為是 quantized 4-bit 需做 signed 轉換 + P act_val; + if constexpr (std::is_same::value) { + uint8_t raw = act_row[a]; + // two's complement mapping for 4-bit + act_val = static_cast

( raw < (weight_levels_/2) ? raw : raw - weight_levels_ ); + } else { + act_val = static_cast

(act_row[a]); + } + // 計算乘積 + + int64_t prod_int = static_cast(signed_w) * static_cast(act_val); + // 飽和處理 + if (prod_int > std::numeric_limits

::max()) prod_int = std::numeric_limits

::max(); + else if (prod_int < std::numeric_limits

::min()) prod_int = std::numeric_limits

::min(); + row_ptr[a] = static_cast

(prod_int); + } + // padding 部分不必初始化 + } + } + + // 取得 weight-level w 的向量查表指標 + inline const P* get_row(std::size_t w) const noexcept { + return &table_[w * padded_a_range_]; } - inline ProductType operator()(W w, A a) const noexcept { + + // 元素查表,保留舊接口 + inline ProductType get(std::size_t w, std::size_t a) const noexcept { + return table_[w * padded_a_range_ + a]; + } + inline ProductType operator()(std::size_t w, std::size_t a) const noexcept { return get(w, a); } - // ---- helpers ---- - const ProductType* data() const noexcept { return table_.data(); } - std::size_t row_stride() const noexcept { return padded_a_range_; } - std::size_t weight_range() const noexcept { return w_range_; } - std::size_t activation_range() const noexcept { return a_range_; } - std::size_t lut_size_bytes() const noexcept { return table_.size() * sizeof(ProductType);} + // table 資料屬性 + const P* data() const noexcept { return table_.data(); } + std::size_t row_stride() const noexcept { return padded_a_range_; } + std::size_t weight_levels() const noexcept { return weight_levels_; } + std::size_t activation_range() const noexcept { return a_range_; } + std::size_t lut_size_bytes() const noexcept { return table_.size() * sizeof(P); } private: - std::size_t w_range_, a_range_, padded_a_range_; - std::vector> table_; // flat buffer (w * a_range + a) -}; \ No newline at end of file + std::size_t weight_levels_, a_range_, padded_a_range_; + std::vector> table_; // flat buffer (w * padded + a) +}; diff --git a/src/matrix_ops.hpp b/src/matrix_ops.hpp index 3651a28..a625726 100644 --- a/src/matrix_ops.hpp +++ b/src/matrix_ops.hpp @@ -8,6 +8,7 @@ #include #include #include +#include // ============================================================= // Helper: unpack a Matrix<> that uses Int4Storage into a @@ -135,105 +136,53 @@ Matrix matmul_mkl(const Matrix& A, const Matrix& B) { // Works with or without AVX2 (scalar fallback). 
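
How the reworked table is meant to be driven: a usage sketch, assuming ProductLookupTable as it appears in the hunk above (the main() driver and its sample values are illustrative only):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include "lut_utils.hpp"

    int main() {
        const std::size_t N = 4;
        // One activation row, already quantized to unsigned int4 nibbles:
        // {1, 2, 13, 7} encodes the signed values {1, 2, -3, 7}.
        uint8_t act_row[N] = {1, 2, 13, 7};

        // 16 weight levels (signed int4 stored as nibbles), rows of length N
        ProductLookupTable<uint8_t, uint8_t, int32_t> lut(16, N);
        lut.fill_from_activation(act_row);

        // Row for the weight nibble 14 (= -2): expect -2 -4 6 -14
        const int32_t* row = lut.get_row(14);
        for (std::size_t j = 0; j < N; ++j) std::printf("%d ", row[j]);
        std::printf("\n");
        return 0;
    }
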
diff --git a/src/matrix_ops.hpp b/src/matrix_ops.hpp
index 3651a28..a625726 100644
--- a/src/matrix_ops.hpp
+++ b/src/matrix_ops.hpp
@@ -8,6 +8,7 @@
 #include <cstdint>
 #include <vector>
 #include <thread>
+#include <algorithm>

 // =============================================================
 // Helper: unpack a Matrix<> that uses Int4Storage into a
 // flat std::vector<uint8_t> (one nibble per element).
@@ -135,105 +136,53 @@ Matrix<float> matmul_mkl(const Matrix<float>& A, const Matrix<float>& B) {
 // Works with or without AVX2 (scalar fallback).
 // =============================================================
-auto matmul_lut_fast(const std::vector<uint8_t>& Au,
-                     const std::vector<uint8_t>& Bu,
-                     size_t M, size_t K, size_t N,
-                     const ProductLookupTable<uint8_t, uint8_t, int32_t>& lut,
-                     size_t block_size = 64,   // larger blocks to reduce synchronization overhead
-                     size_t num_threads = 4)
-{
-    const int32_t* lut_ptr = lut.data();
-    const int32_t stride = static_cast<int32_t>(lut.row_stride());
+// LUT-based mixed-precision GEMM kernel
+template <typename A>
+auto matmul_lut_fast(const std::vector<uint8_t>& W,
+                     const std::vector<A>& A_mat,
+                     size_t M, size_t K, size_t N,
+                     ProductLookupTable<uint8_t, A, int32_t>& lut,
+                     size_t block_size = 64,
+                     size_t num_threads = 4) {
+    // Result matrix
     Matrix<int32_t> C(M, N);
     std::vector<std::thread> threads;
     std::mutex mtx;
-
-    // Prefetch the LUT into the L1 cache
-    for (size_t i = 0; i < 16; ++i) {
-        for (size_t j = 0; j < 16; ++j) {
-            _mm_prefetch(reinterpret_cast<const char*>(&lut_ptr[i * stride + j]), _MM_HINT_T0);
-        }
-    }
-
-    // Number of rows handled by each thread
     size_t rows_per_thread = (M + num_threads - 1) / num_threads;

-    // Assign work to each thread
     for (size_t t = 0; t < num_threads; ++t) {
         threads.emplace_back([&, t]() {
-            size_t start_row = t * rows_per_thread;
-            size_t end_row = std::min(start_row + rows_per_thread, M);
-
-            // Thread-local result matrix
-            Matrix<int32_t> local_C(M, N);
-
-            // Blocked processing
-            for (size_t i = start_row; i < end_row; i += block_size) {
-                size_t i_end = std::min(i + block_size, end_row);
-
-                for (size_t j = 0; j < N; j += block_size) {
-                    size_t j_end = std::min(j + block_size, N);
-
-                    for (size_t k = 0; k < K; k += block_size) {
-                        size_t k_end = std::min(k + block_size, K);
-
-                        // Process the current block
+            size_t row_start = t * rows_per_thread;
+            size_t row_end = std::min(row_start + rows_per_thread, M);
+            Matrix<int32_t> localC(M, N);
+            for (size_t i = row_start; i < row_end; i += block_size) {
+                size_t i_end = std::min(i + block_size, row_end);
+                for (size_t k = 0; k < K; k += block_size) {
+                    size_t k_end = std::min(k + block_size, K);
+                    // For each k in this block, rebuild the LUT and accumulate
+                    for (size_t kk = k; kk < k_end; ++kk) {
+                        const A* act_row = &A_mat[kk * N];
+                        lut.fill_from_activation(act_row);
+                        // Accumulate for each row i
                         for (size_t ii = i; ii < i_end; ++ii) {
-                            const uint8_t* rowA = &Au[ii * K];
-
-                            for (size_t jj = j; jj < j_end; ++jj) {
-                                int32_t acc = 0;
-
-#if defined(__AVX2__)
-                                // Process 8 elements at a time with AVX2
-                                for (size_t kk = k; kk + 7 < k_end; kk += 8) {
-                                    __m128i w8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(rowA + kk));
-                                    __m128i a8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&Bu[kk * N + jj]));
-
-                                    __m256i w32 = _mm256_cvtepu8_epi32(w8);
-                                    __m256i a32 = _mm256_cvtepu8_epi32(a8);
-                                    __m256i idx = _mm256_add_epi32(_mm256_mullo_epi32(w32, _mm256_set1_epi32(stride)), a32);

-                                    // Prefetch the next LUT values
-                                    _mm_prefetch(reinterpret_cast<const char*>(&lut_ptr[_mm256_extract_epi32(idx, 0)]), _MM_HINT_T0);
-
-                                    __m256i vals = _mm256_i32gather_epi32(lut_ptr, idx, 4);
-
-                                    // Horizontal sum
-                                    __m128i low  = _mm256_castsi256_si128(vals);
-                                    __m128i high = _mm256_extracti128_si256(vals, 1);
-                                    __m128i sum  = _mm_add_epi32(low, high);
-                                    sum = _mm_hadd_epi32(sum, sum);
-                                    sum = _mm_hadd_epi32(sum, sum);
-                                    acc += _mm_cvtsi128_si32(sum);
-                                }
-#endif
-                                // Handle the remaining elements
-                                for (size_t kk = k + ((k_end - k) & ~7); kk < k_end; ++kk) {
-                                    acc += lut_ptr[rowA[kk] * stride + Bu[kk * N + jj]];
-                                }
-
-                                local_C.set(ii, jj, local_C.at(ii, jj) + acc);
+                            uint8_t q = W[ii * K + kk];
+                            const int32_t* lut_row = lut.get_row(q);
+                            for (size_t j = 0; j < N; ++j) {
+                                localC.set(ii, j, localC.at(ii, j) + lut_row[j]);
                             }
                         }
                     }
                 }
             }
-
-            // Merge the results
+            // Merge into C
             std::lock_guard<std::mutex> lock(mtx);
-            for (size_t i = start_row; i < end_row; ++i) {
+            for (size_t ii = row_start; ii < row_end; ++ii) {
                 for (size_t j = 0; j < N; ++j) {
-                    C.set(i, j, C.at(i, j) + local_C.at(i, j));
+                    C.set(ii, j, C.at(ii, j) + localC.at(ii, j));
                 }
             }
         });
     }
-
-    // Wait for all threads to finish
-    for (auto& thread : threads) {
-        thread.join();
-    }
-
+    for (auto& thr : threads) thr.join();
     return C;
 }
\ No newline at end of file
diff --git a/tests/test_correctness.cpp b/tests/test_correctness.cpp
index 7719a3d..21e18c8 100644
--- a/tests/test_correctness.cpp
+++ b/tests/test_correctness.cpp
@@ -226,18 +226,35 @@ bool run_int4_fast_test(){
     using Mat4R = Matrix<Int4RowMajor>;
     using Mat4C = Matrix<Int4ColMajor>;

     std::mt19937 rng(42);
-    std::uniform_int_distribution<int> d4(0,15);
+    // Change the distribution range to -8..7
+    std::uniform_int_distribution<int> d4(-8,7);
     Mat4R A4(M,K);
     Mat4C B4(K,N);
-    for(int i=0;i<M;++i) for(int j=0;j<K;++j) A4.set(i,j, d4(rng));
-    for(int i=0;i<K;++i) for(int j=0;j<N;++j) B4.set(i,j, d4(rng));
+    // Store the signed draws in their unsigned int4 representation
+    for(int i=0;i<M;++i) for(int j=0;j<K;++j){
+        int v = d4(rng);
+        A4.set(i,j, uint8_t(v < 0 ? v + 16 : v));
+    }
+    for(int i=0;i<K;++i) for(int j=0;j<N;++j){
+        int v = d4(rng);
+        B4.set(i,j, uint8_t(v < 0 ? v + 16 : v));
+    }

     // unpack -> naive matmul
     auto Au = unpack_int4(A4);
     auto Bu = unpack_int4(B4);
     Matrix<int32_t> Au_mat(M,K), Bu_mat(K,N);
-    for(int i=0;i<M;++i) for(int j=0;j<K;++j) Au_mat.set(i,j, Au[size_t(i)*K+j]);
-    for(int i=0;i<K;++i) for(int j=0;j<N;++j) Bu_mat.set(i,j, Bu[size_t(i)*N+j]);
+    // Map the unpacked unsigned nibbles back to signed int4 for the reference
+    for(int i=0;i<M;++i) for(int j=0;j<K;++j){
+        int v = Au[size_t(i)*K+j];
+        Au_mat.set(i,j, v < 8 ? v : v - 16);
+    }
+    for(int i=0;i<K;++i) for(int j=0;j<N;++j){
+        int v = Bu[size_t(i)*N+j];
+        Bu_mat.set(i,j, v < 8 ? v : v - 16);
+    }

From 254c07c8fca7c2c262b6d32c4ed6dc3e4b986782 Mon Sep 17 00:00:00 2001
From: 5000user5000
Date: Thu, 15 May 2025 14:33:48 +0800
Subject: [PATCH 4/4] remove debug message printout

---
 scripts/example.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/scripts/example.py b/scripts/example.py
index ea409a4..443c76d 100644
--- a/scripts/example.py
+++ b/scripts/example.py
@@ -18,12 +18,10 @@ def main():
     # Generate random INT4 weights (-8 to +7)
     weights = rng.integers(-8, 8, size=(M, K), dtype=np.int8)
-    print("weights:")
-    print(weights)
+
     # Convert the signed int4 values to their unsigned representation
     weights_unsigned = np.where(weights < 0, weights + 16, weights).astype(np.uint8)
-    print("weights_unsigned:")
-    print(weights_unsigned)
+
     # Generate random FP16 activations (standard normal distribution)
     activations = rng.standard_normal(size=(K, N)).astype(np.float16)
     # Random bias (FP32, kept within a reasonable range)
     bias = rng.uniform(-1, 1, size=N).astype(np.float32)
@@ -37,15 +35,13 @@ def main():
     # === 1. Baseline reference output ===
     gemm_ref = mpgemm.Engine("naive")
     ref_flat = gemm_ref.matmul(w_flat, a_flat, M, K, N)
-    print("ref_flat:")
-    print(ref_flat)
+
     # === 2. LUT backend output ===
     gemm_lut = mpgemm.Engine("lut")
     gemm_lut.generate_lut(bit_width=4)
     out_flat = gemm_lut.matmul(w_flat, a_flat, M, K, N)
-    print("out_flat:")
-    print(out_flat)
+
     # === 3. Post-processing example ===
     out_biased = gemm_lut.add_bias(out_flat, M, N, bias_list)
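
Taken together, the PATCH 3 kernel boils down to the loop structure below: for each k, rebuild a 16-row product table from the k-th activation row (the role of fill_from_activation), then accumulate whole rows of C with one weight-nibble lookup per element of W. A simplified single-threaded, unblocked sketch with flat arrays (illustrative only; lut_matmul and its layout are not the engine's actual API):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    std::vector<int32_t> lut_matmul(const std::vector<uint8_t>& W,  // M*K nibbles (0..15)
                                    const std::vector<int32_t>& A,  // K*N signed activations
                                    std::size_t M, std::size_t K, std::size_t N) {
        std::vector<int32_t> C(M * N, 0);
        std::vector<int32_t> row(16 * N);  // products for all 16 weight levels
        for (std::size_t k = 0; k < K; ++k) {
            // Analogous to fill_from_activation: one product row per weight level
            for (int w = 0; w < 16; ++w) {
                int sw = w < 8 ? w : w - 16;  // unsigned nibble -> signed weight
                for (std::size_t j = 0; j < N; ++j)
                    row[w * N + j] = sw * A[k * N + j];
            }
            // Accumulate a whole row of products per weight lookup
            for (std::size_t i = 0; i < M; ++i) {
                const int32_t* r = &row[W[i * K + k] * N];
                for (std::size_t j = 0; j < N; ++j)
                    C[i * N + j] += r[j];
            }
        }
        return C;
    }

    int main() {
        // 1x2 times 2x1: W = [3, -2] (stored as nibbles 3 and 14), A = [5, 7]^T
        std::vector<uint8_t> W = {3, 14};
        std::vector<int32_t> A = {5, 7};
        std::printf("%d\n", lut_matmul(W, A, 1, 2, 1)[0]);  // 3*5 + (-2)*7 = 1
        return 0;
    }
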