Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
CXX := g++
CXXFLAGS := -std=c++17 -O2 -Wall -I./src -march=native
CXXFLAGS := -std=c++17 -O2 -Wall -I./src -march=native -fopenmp

PYBIND11_INC := $(shell python3 -m pybind11 --includes)
PYEXT := $(shell python3-config --extension-suffix)
Expand All @@ -14,6 +14,10 @@ ifeq ($(USE_MKL),1)
endif
# --------------------

# ---- OpenMP runtime ----
# Note: this is not actually a toggle — -fopenmp is unconditionally enabled
# in CXXFLAGS above, so libgomp is always linked.
LDLIBS += -lgomp
# --------------------


SRC_DIR := src
TEST_DIR := tests
Expand Down
2 changes: 1 addition & 1 deletion doc/proposal.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ common precision, as hardware lacks native support for mixed-precision matrix
multiplication (mpGEMM). Some recent research suggests using lookup tables
(LUTs) to replace dequantization, further reducing computational overhead.

![LUT](./img/lut.png)
![LUT](../img/lut.png)

## Problem to Solve

Expand Down
20 changes: 13 additions & 7 deletions scripts/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,38 @@

def main():
# 矩阵维度
M, K, N = 128, 128, 128
M, K, N = 12, 12, 12

# 随机数生成器
rng = np.random.default_rng(2025)

# 生成随机 INT4 权重(0-15)
weights = rng.integers(0, 16, size=(M, K), dtype=np.uint8)
# 生成随机 FP16 激活
# 生成随机 INT4 权重(-8 到 +7)
weights = rng.integers(-8, 8, size=(M, K), dtype=np.int8)

# 將有符號 int4 轉換為無符號表示
weights_unsigned = np.where(weights < 0, weights + 16, weights).astype(np.uint8)

# 生成随机 FP16 激活(使用標準正態分佈)
activations = rng.standard_normal(size=(K, N)).astype(np.float16)
# 随机 bias(FP32)
bias = rng.standard_normal(size=N).astype(np.float32)
# 随机 bias(FP32,範圍也限制在合理範圍內
bias = rng.uniform(-1, 1, size=N).astype(np.float32)

# 扁平化并转为 Python 列表
w_flat = weights.flatten().tolist()
w_flat = weights_unsigned.flatten().tolist()
a_flat = activations.flatten().astype(float).tolist()
bias_list = bias.tolist()

# === 1. 基准参考输出 ===
gemm_ref = mpgemm.Engine("naive")
ref_flat = gemm_ref.matmul(w_flat, a_flat, M, K, N)


# === 2. LUT 后端输出 ===
gemm_lut = mpgemm.Engine("lut")
gemm_lut.generate_lut(bit_width=4)
out_flat = gemm_lut.matmul(w_flat, a_flat, M, K, N)


# === 3. 后处理示例 ===
out_biased = gemm_lut.add_bias(out_flat, M, N, bias_list)
out_relu = gemm_lut.apply_activation(out_biased, M, N, Activation.ReLU)
Expand Down
24 changes: 17 additions & 7 deletions src/gemm_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class Engine {
: lut(nullptr)
{
if (backend_str == "naive") backend = Backend::Naive;
else if (backend_str == "lut") backend = Backend::LUT;
else if (backend_str == "lut") backend = Backend::LUT;
#ifdef USE_MKL
else if (backend_str == "mkl") backend = Backend::MKL;
else if (backend_str == "mkl") backend = Backend::MKL;
#endif
else throw std::invalid_argument("Unknown backend: " + backend_str);
}
Expand All @@ -55,10 +55,12 @@ class Engine {
switch (backend) {
case Backend::Naive: {
Matrix<int,RowMajor,PlainStorage<int>> Wi(M, K), Ai(K, N);
// 填充 Wi, Ai
// 填充 Wi, Ai,將 uint8_t 轉換為有符號 int
for (int i = 0; i < M; ++i)
for (int j = 0; j < K; ++j)
Wi.set(i, j, (int)Wflat[size_t(i)*K + j]);
for (int j = 0; j < K; ++j) {
int val = Wflat[size_t(i)*K + j];
Wi.set(i, j, val < 8 ? val : val - 16); // 轉換為有符號
}
for (int i = 0; i < K; ++i)
for (int j = 0; j < N; ++j)
Ai.set(i, j, (int)std::lround(Aflat[size_t(i)*N + j]));
Expand All @@ -77,11 +79,19 @@ class Engine {
for (int j = 0; j < K; ++j)
Wq.set(i,j, Wflat[size_t(i)*K + j]);
auto Wu = unpack_int4(Wq);

// 把 Activation floats 截断、cast 为 uint8 索引
std::vector<uint8_t> Au(size_t(K)*N);
for (int i = 0; i < K; ++i)
for (int j = 0; j < N; ++j)
Au[size_t(i)*N + j] = uint8_t(std::lround(Aflat[size_t(i)*N + j]));
for (int j = 0; j < N; ++j) {
float val = Aflat[size_t(i)*N + j];
// 將浮點數轉換為有符號 int4 範圍
int q = std::lround(val);
q = std::clamp(q, -8, 7); // 限制在有符號 int4 範圍
// 轉換為無符號表示
Au[size_t(i)*N + j] = uint8_t(q < 0 ? q + 16 : q);
}

auto Ci = matmul_lut_fast(Wu, Au, M, K, N, *lut);
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
Expand Down
105 changes: 81 additions & 24 deletions src/lut_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <type_traits>
#include <cstdlib> // for posix_memalign, free
#include <memory>
#include <limits>
#include <iostream>

// 對齊分配器:以 Align 對齊分配
template<typename T, std::size_t Align>
Expand Down Expand Up @@ -40,42 +42,97 @@ struct AlignedAllocator {

// lookup table for product of two integers
// LUT:64-byte 對齊,並將每列 padding 到 8 的倍數
template <typename W, typename A,
typename P = std::common_type_t<W, A>>
template <typename W, typename A, typename P = std::common_type_t<W, A>>
class ProductLookupTable {
public:
using WeightType = W;
using ActivationType = A;
using ProductType = P;

ProductLookupTable(std::size_t w_range, std::size_t a_range)
: w_range_(w_range),
a_range_(a_range),
padded_a_range_(((a_range + 7) / 8) * 8),
table_(w_range * padded_a_range_) // 用 padding 後的尺寸
// 建構内部初始化,接受權重量化層數與 activation 長度
ProductLookupTable(std::size_t weight_levels, std::size_t a_range)
: weight_levels_(weight_levels),
a_range_(a_range),
padded_a_range_(((a_range + 7) / 8) * 8),
table_(weight_levels * padded_a_range_)
{
for (std::size_t w = 0; w < w_range_; ++w)
for (std::size_t a = 0; a < a_range_; ++a)
table_[w * padded_a_range_ + a] =
static_cast<ProductType>(w) * static_cast<ProductType>(a);
// Default initialization: scalar LUT for weight × activation (raw) values
for (std::size_t w = 0; w < weight_levels_; ++w) {
int signed_w = static_cast<int>(w < (weight_levels_/2) ? w : w - weight_levels_);
P* row_ptr = &table_[w * padded_a_range_];
for (std::size_t a = 0; a < a_range_; ++a) {
// Activation index a treated as raw ActivationType
ActivationType act_val = static_cast<ActivationType>(a);
int64_t prod = static_cast<int64_t>(signed_w) * static_cast<int64_t>(act_val);
// Saturate to ProductType range
if (prod > std::numeric_limits<P>::max()) prod = std::numeric_limits<P>::max();
else if (prod < std::numeric_limits<P>::min()) prod = std::numeric_limits<P>::min();
row_ptr[a] = static_cast<P>(prod);
}
// padding region left uninitialized
}
}

// ---- scalar accessors ----
inline ProductType get(W w, A a) const noexcept {
return table_[size_t(w) * padded_a_range_ + size_t(a)];
/**
* fill_from_activation:
* 根據輸入的 activation row (act_row, 長度 = a_range_),
* 為所有 weight level 預先計算並填充查表數據:
* 對於每個 weight w (0..weight_levels_-1),
* 1. 轉為有符號值 signed_w
* 2. 遍歷所有 activation 值 act_row[a]
* 3. 計算 signed_w * act_row[a],並做飽和處理
* 4. 存入 table_[w * padded_a_range_ + a]
*
* 這樣,後續 matmul 可直接對每個 weight lookup 整列乘積向量,
* 而不需即時計算乘法。
*/
void fill_from_activation(const ActivationType* act_row) {
for (std::size_t w = 0; w < weight_levels_; ++w) {
int signed_w = static_cast<int>(w < (weight_levels_/2) ? w : w - weight_levels_);
P* row_ptr = &table_[w * padded_a_range_];
for (std::size_t a = 0; a < a_range_; ++a) {
// 讀取 activation 值;若是 uint8_t,則認為是 quantized 4-bit 需做 signed 轉換
P act_val;
if constexpr (std::is_same<ActivationType, uint8_t>::value) {
uint8_t raw = act_row[a];
// two's complement mapping for 4-bit
act_val = static_cast<P>( raw < (weight_levels_/2) ? raw : raw - weight_levels_ );
} else {
act_val = static_cast<P>(act_row[a]);
}
// 計算乘積

int64_t prod_int = static_cast<int64_t>(signed_w) * static_cast<int64_t>(act_val);
// 飽和處理
if (prod_int > std::numeric_limits<P>::max()) prod_int = std::numeric_limits<P>::max();
else if (prod_int < std::numeric_limits<P>::min()) prod_int = std::numeric_limits<P>::min();
row_ptr[a] = static_cast<P>(prod_int);
}
// padding 部分不必初始化
}
}

// 取得 weight-level w 的向量查表指標
inline const P* get_row(std::size_t w) const noexcept {
return &table_[w * padded_a_range_];
}
inline ProductType operator()(W w, A a) const noexcept {

// 元素查表,保留舊接口
inline ProductType get(std::size_t w, std::size_t a) const noexcept {
return table_[w * padded_a_range_ + a];
}
inline ProductType operator()(std::size_t w, std::size_t a) const noexcept {
return get(w, a);
}

// ---- helpers ----
const ProductType* data() const noexcept { return table_.data(); }
std::size_t row_stride() const noexcept { return padded_a_range_; }
std::size_t weight_range() const noexcept { return w_range_; }
std::size_t activation_range() const noexcept { return a_range_; }
std::size_t lut_size_bytes() const noexcept { return table_.size() * sizeof(ProductType);}
// table 資料屬性
const P* data() const noexcept { return table_.data(); }
std::size_t row_stride() const noexcept { return padded_a_range_; }
std::size_t weight_levels() const noexcept { return weight_levels_; }
std::size_t activation_range() const noexcept { return a_range_; }
std::size_t lut_size_bytes() const noexcept { return table_.size() * sizeof(P); }

private:
std::size_t w_range_, a_range_, padded_a_range_;
std::vector<ProductType, AlignedAllocator<ProductType,64>> table_; // flat buffer (w * a_range + a)
};
std::size_t weight_levels_, a_range_, padded_a_range_;
std::vector<P, AlignedAllocator<P, 64>> table_; // flat buffer (w * padded + a)
};
111 changes: 30 additions & 81 deletions src/matrix_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <immintrin.h>
#include <thread>
#include <mutex>
#include <iostream>

// =============================================================
// Helper: unpack a Matrix<> that uses Int4Storage into a
Expand Down Expand Up @@ -135,105 +136,53 @@ Matrix<T> matmul_mkl(const Matrix<T>& A, const Matrix<T>& B) {
// Works with or without AVX2 (scalar fallback).
// =============================================================

auto matmul_lut_fast(const std::vector<uint8_t>& Au,
const std::vector<uint8_t>& Bu,
size_t M, size_t K, size_t N,
const ProductLookupTable<uint8_t, uint8_t, int32_t>& lut,
size_t block_size = 64, // 增加區塊大小以減少同步開銷
size_t num_threads = 4)
{
const int32_t* lut_ptr = lut.data();
const int32_t stride = static_cast<int32_t>(lut.row_stride());
// LUT-based mixed-precision GEMM kernel
template <typename A>
auto matmul_lut_fast(const std::vector<uint8_t>& W,
const std::vector<A>& A_mat,
size_t M, size_t K, size_t N,
ProductLookupTable<uint8_t, A, int32_t>& lut,
size_t block_size = 64,
size_t num_threads = 4) {
// Result matrix
Matrix<int32_t, RowMajor, PlainStorage<int32_t>> C(M, N);
std::vector<std::thread> threads;
std::mutex mtx;

// 預取 LUT 到 L1 cache
for (size_t i = 0; i < 16; ++i) {
for (size_t j = 0; j < 16; ++j) {
_mm_prefetch(reinterpret_cast<const char*>(&lut_ptr[i * stride + j]), _MM_HINT_T0);
}
}

// 計算每個執行緒處理的行數
size_t rows_per_thread = (M + num_threads - 1) / num_threads;

// 為每個執行緒分配工作
for (size_t t = 0; t < num_threads; ++t) {
threads.emplace_back([&, t]() {
size_t start_row = t * rows_per_thread;
size_t end_row = std::min(start_row + rows_per_thread, M);

// 為每個執行緒創建局部結果矩陣
Matrix<int32_t, RowMajor, PlainStorage<int32_t>> local_C(M, N);

// 分塊處理
for (size_t i = start_row; i < end_row; i += block_size) {
size_t i_end = std::min(i + block_size, end_row);

for (size_t j = 0; j < N; j += block_size) {
size_t j_end = std::min(j + block_size, N);

for (size_t k = 0; k < K; k += block_size) {
size_t k_end = std::min(k + block_size, K);

// 處理當前區塊
size_t row_start = t * rows_per_thread;
size_t row_end = std::min(row_start + rows_per_thread, M);
Matrix<int32_t, RowMajor, PlainStorage<int32_t>> localC(M, N);
for (size_t i = row_start; i < row_end; i += block_size) {
size_t i_end = std::min(i + block_size, row_end);
for (size_t k = 0; k < K; k += block_size) {
size_t k_end = std::min(k + block_size, K);
// For each k in this block, rebuild LUT and accumulate
for (size_t kk = k; kk < k_end; ++kk) {
const A* act_row = &A_mat[kk * N];
lut.fill_from_activation(act_row);
// Accumulate for each row i
for (size_t ii = i; ii < i_end; ++ii) {
const uint8_t* rowA = &Au[ii * K];

for (size_t jj = j; jj < j_end; ++jj) {
int32_t acc = 0;

#if defined(__AVX2__)
// 使用 AVX2 處理 8 個元素
for (size_t kk = k; kk + 7 < k_end; kk += 8) {
__m128i w8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(rowA + kk));
__m128i a8 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&Bu[kk * N + jj]));

__m256i w32 = _mm256_cvtepu8_epi32(w8);
__m256i a32 = _mm256_cvtepu8_epi32(a8);
__m256i idx = _mm256_add_epi32(_mm256_mullo_epi32(w32, _mm256_set1_epi32(stride)), a32);

// 預取下一個 LUT 值
_mm_prefetch(reinterpret_cast<const char*>(&lut_ptr[_mm256_extract_epi32(idx, 0)]), _MM_HINT_T0);

__m256i vals = _mm256_i32gather_epi32(lut_ptr, idx, 4);

// 水平加總
__m128i low = _mm256_castsi256_si128(vals);
__m128i high = _mm256_extracti128_si256(vals, 1);
__m128i sum = _mm_add_epi32(low, high);
sum = _mm_hadd_epi32(sum, sum);
sum = _mm_hadd_epi32(sum, sum);
acc += _mm_cvtsi128_si32(sum);
}
#endif
// 處理剩餘元素
for (size_t kk = k + ((k_end - k) & ~7); kk < k_end; ++kk) {
acc += lut_ptr[rowA[kk] * stride + Bu[kk * N + jj]];
}

local_C.set(ii, jj, local_C.at(ii, jj) + acc);
uint8_t q = W[ii * K + kk];
const int32_t* lut_row = lut.get_row(q);
for (size_t j = 0; j < N; ++j) {
localC.set(ii, j, localC.at(ii, j) + lut_row[j]);
}
}
}
}
}

// 合併結果
// Merge into C
std::lock_guard<std::mutex> lock(mtx);
for (size_t i = start_row; i < end_row; ++i) {
for (size_t ii = row_start; ii < row_end; ++ii) {
for (size_t j = 0; j < N; ++j) {
C.set(i, j, C.at(i, j) + local_C.at(i, j));
C.set(ii, j, C.at(ii, j) + localC.at(ii, j));
}
}
});
}

// 等待所有執行緒完成
for (auto& thread : threads) {
thread.join();
}

for (auto& thr : threads) thr.join();
return C;
}
Loading