Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ matrix_ops: $(TARGET_MATRIX_OPS)
@echo "\nRunning LUT version..."
@./$(TARGET_MATRIX_OPS) lut

# 方便的命令
matrix_ops_float: $(TARGET_MATRIX_OPS)
./$(TARGET_MATRIX_OPS) float

Expand Down
22 changes: 9 additions & 13 deletions scripts/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,53 @@
import os
import sys

# 确保项目根目录在 PYTHONPATH
sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, '..')))

import numpy as np
import mpgemm
from mpgemm import Activation

def main():
# 矩阵维度

M, K, N = 12, 12, 12

# 随机数生成器
rng = np.random.default_rng(2025)

# 生成随机 INT4 权重(-8 到 +7)
# generate random int4 weights in range [-8, 7]
weights = rng.integers(-8, 8, size=(M, K), dtype=np.int8)

# 將有符號 int4 轉換為無符號表示
weights_unsigned = np.where(weights < 0, weights + 16, weights).astype(np.uint8)

# 生成随机 FP16 激活(使用標準正態分佈)
# generate random activations
activations = rng.standard_normal(size=(K, N)).astype(np.float16)
# 随机 bias(FP32,範圍也限制在合理範圍內)

bias = rng.uniform(-1, 1, size=N).astype(np.float32)

# 扁平化并转为 Python 列表
w_flat = weights_unsigned.flatten().tolist()
a_flat = activations.flatten().astype(float).tolist()
bias_list = bias.tolist()

# === 1. 基准参考输出 ===
# === 1. Baseline: Naive GEMM ===
gemm_ref = mpgemm.Engine("naive")
ref_flat = gemm_ref.matmul(w_flat, a_flat, M, K, N)


# === 2. LUT 后端输出 ===
# === 2. LUT GEMM ===
gemm_lut = mpgemm.Engine("lut")
gemm_lut.generate_lut(bit_width=4)
out_flat = gemm_lut.matmul(w_flat, a_flat, M, K, N)


# === 3. 后处理示例 ===
# === 3. Post-processing ===
out_biased = gemm_lut.add_bias(out_flat, M, N, bias_list)
out_relu = gemm_lut.apply_activation(out_biased, M, N, Activation.ReLU)

# 还原成矩阵
# === 4. Output ===
output = np.array(out_relu, dtype=np.float32).reshape(M, N)
print(f"Output shape: {output.shape}")
print("Sample row[0, :5]:", output[0, :5])

# === 4. 误差分析示例 ===
# === 5. Error measurement ===
stats = mpgemm.measure_error(ref_flat, out_flat)
print(f"\nError relative to naive:")
print(f" MSE = {stats['mse']:.6f}")
Expand Down
14 changes: 3 additions & 11 deletions src/gemm_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ enum class Backend {

class Engine {
public:
// 构造:支持 "naive"/"lut" (和 "mkl")
Engine(const std::string &backend_str)
: lut(nullptr)
{
Expand All @@ -33,7 +32,6 @@ class Engine {
else throw std::invalid_argument("Unknown backend: " + backend_str);
}

// 仅在 LUT 模式下调用,bit_width 一般传 4
void generate_lut(int bit_width) {
if (backend != Backend::LUT)
throw std::runtime_error("generate_lut only valid for LUT backend");
Expand All @@ -43,7 +41,6 @@ class Engine {
lut = std::make_unique<ProductLookupTable<uint8_t,uint8_t,int32_t>>(range, range);
}

// matmul: Wflat 是 uint8_t(int4 存 raw),Aflat 是 float(list)
std::vector<float> matmul(
const std::vector<uint8_t>& Wflat,
const std::vector<float>& Aflat,
Expand All @@ -55,16 +52,14 @@ class Engine {
switch (backend) {
case Backend::Naive: {
Matrix<int,RowMajor,PlainStorage<int>> Wi(M, K), Ai(K, N);
// 填充 Wi, Ai,將 uint8_t 轉換為有符號 int
for (int i = 0; i < M; ++i)
for (int j = 0; j < K; ++j) {
int val = Wflat[size_t(i)*K + j];
Wi.set(i, j, val < 8 ? val : val - 16); // 轉換為有符號
Wi.set(i, j, val < 8 ? val : val - 16);
}
for (int i = 0; i < K; ++i)
for (int j = 0; j < N; ++j)
Ai.set(i, j, (int)std::lround(Aflat[size_t(i)*N + j]));
// 调用全局 ::matmul
auto C = ::matmul(Wi, Ai);
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
Expand All @@ -73,22 +68,19 @@ class Engine {
}
case Backend::LUT: {
if (!lut) throw std::runtime_error("LUT not generated");
// 从 raw uint8 构造 Int4Storage 矩阵并 unpack

Matrix<uint8_t,RowMajor,Int4Storage> Wq(M,K);
for (int i = 0; i < M; ++i)
for (int j = 0; j < K; ++j)
Wq.set(i,j, Wflat[size_t(i)*K + j]);
auto Wu = unpack_int4(Wq);

// 把 Activation floats 截断、cast 为 uint8 索引
std::vector<uint8_t> Au(size_t(K)*N);
for (int i = 0; i < K; ++i)
for (int j = 0; j < N; ++j) {
float val = Aflat[size_t(i)*N + j];
// 將浮點數轉換為有符號 int4 範圍
int q = std::lround(val);
q = std::clamp(q, -8, 7); // 限制在有符號 int4 範圍
// 轉換為無符號表示
q = std::clamp(q, -8, 7);
Au[size_t(i)*N + j] = uint8_t(q < 0 ? q + 16 : q);
}

Expand Down
115 changes: 54 additions & 61 deletions src/lut_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
#pragma once

#include <vector>
#include <cstddef>
#include <type_traits>
#include <cstdlib> // for posix_memalign, free
#include <cstdlib> // posix_memalign, free
#include <memory>
#include <limits>
#include <iostream>
#include <cstdint>

// 對齊分配器:以 Align 對齊分配
// Aligned allocator with compile-time check
template<typename T, std::size_t Align>
struct AlignedAllocator {
static_assert((Align & (Align - 1)) == 0 && Align >= alignof(T),
"Align must be power-of-two and at least alignof(T)");

using value_type = T;
using pointer = T*;
using const_pointer = const T*;
Expand Down Expand Up @@ -40,92 +44,57 @@ struct AlignedAllocator {
}
};

// lookup table for product of two integers
// LUT:64-byte 對齊,並將每列 padding 到 8 的倍數
// Lookup table for product of two integers, 64-byte aligned
// W: weight type, A: activation type, P: product type
template <typename W, typename A, typename P = std::common_type_t<W, A>>
class ProductLookupTable {
public:
using WeightType = W;
using ActivationType = A;
using ProductType = P;

// 建構内部初始化,接受權重量化層數與 activation 長度
ProductLookupTable(std::size_t weight_levels, std::size_t a_range)
: weight_levels_(weight_levels),
a_range_(a_range),
padded_a_range_(((a_range + 7) / 8) * 8),
table_(weight_levels * padded_a_range_)
{
// Default initialization: scalar LUT for weight × activation (raw) values
for (std::size_t w = 0; w < weight_levels_; ++w) {
int signed_w = static_cast<int>(w < (weight_levels_/2) ? w : w - weight_levels_);
P* row_ptr = &table_[w * padded_a_range_];
for (std::size_t a = 0; a < a_range_; ++a) {
// Activation index a treated as raw ActivationType
ActivationType act_val = static_cast<ActivationType>(a);
int64_t prod = static_cast<int64_t>(signed_w) * static_cast<int64_t>(act_val);
// Saturate to ProductType range
if (prod > std::numeric_limits<P>::max()) prod = std::numeric_limits<P>::max();
else if (prod < std::numeric_limits<P>::min()) prod = std::numeric_limits<P>::min();
row_ptr[a] = static_cast<P>(prod);
// Default: build LUT with raw indices as activations
fill_impl([&](std::size_t a) -> int64_t {
int64_t raw = static_cast<int64_t>(a);
// two's-complement for 4-bit (uint8_t) activations
int64_t act = raw;
if constexpr (std::is_same<ActivationType, uint8_t>::value) {
act = (raw < static_cast<int64_t>(weight_levels_ / 2))
? raw : (raw - static_cast<int64_t>(weight_levels_));
}
// padding region left uninitialized
}
return act;
});
}

/**
* fill_from_activation:
* 根據輸入的 activation row (act_row, 長度 = a_range_),
* 為所有 weight level 預先計算並填充查表數據:
* 對於每個 weight w (0..weight_levels_-1),
* 1. 轉為有符號值 signed_w
* 2. 遍歷所有 activation 值 act_row[a]
* 3. 計算 signed_w * act_row[a],並做飽和處理
* 4. 存入 table_[w * padded_a_range_ + a]
*
* 這樣,後續 matmul 可直接對每個 weight lookup 整列乘積向量,
* 而不需即時計算乘法。
*/
void fill_from_activation(const ActivationType* act_row) {
for (std::size_t w = 0; w < weight_levels_; ++w) {
int signed_w = static_cast<int>(w < (weight_levels_/2) ? w : w - weight_levels_);
P* row_ptr = &table_[w * padded_a_range_];
for (std::size_t a = 0; a < a_range_; ++a) {
// 讀取 activation 值;若是 uint8_t,則認為是 quantized 4-bit 需做 signed 轉換
P act_val;
if constexpr (std::is_same<ActivationType, uint8_t>::value) {
uint8_t raw = act_row[a];
// two's complement mapping for 4-bit
act_val = static_cast<P>( raw < (weight_levels_/2) ? raw : raw - weight_levels_ );
} else {
act_val = static_cast<P>(act_row[a]);
}
// 計算乘積

int64_t prod_int = static_cast<int64_t>(signed_w) * static_cast<int64_t>(act_val);
// 飽和處理
if (prod_int > std::numeric_limits<P>::max()) prod_int = std::numeric_limits<P>::max();
else if (prod_int < std::numeric_limits<P>::min()) prod_int = std::numeric_limits<P>::min();
row_ptr[a] = static_cast<P>(prod_int);
void fill_from_activation(const ActivationType* act_row) noexcept {
// Refill LUT with actual activation values
fill_impl([&](std::size_t a) -> int64_t {
int64_t raw = static_cast<int64_t>(act_row[a]);
if constexpr (std::is_same<ActivationType, uint8_t>::value) {
// two's-complement mapping for 4-bit
raw = (raw < static_cast<int64_t>(weight_levels_ / 2))
? raw : (raw - static_cast<int64_t>(weight_levels_));
}
// padding 部分不必初始化
}
return raw;
});
}

// 取得 weight-level w 的向量查表指標
inline const P* get_row(std::size_t w) const noexcept {
return &table_[w * padded_a_range_];
}

// 元素查表,保留舊接口
inline ProductType get(std::size_t w, std::size_t a) const noexcept {
return table_[w * padded_a_range_ + a];
}
inline ProductType operator()(std::size_t w, std::size_t a) const noexcept {
return get(w, a);
}

// table 資料屬性
const P* data() const noexcept { return table_.data(); }
std::size_t row_stride() const noexcept { return padded_a_range_; }
std::size_t weight_levels() const noexcept { return weight_levels_; }
Expand All @@ -134,5 +103,29 @@ class ProductLookupTable {

private:
std::size_t weight_levels_, a_range_, padded_a_range_;
std::vector<P, AlignedAllocator<P, 64>> table_; // flat buffer (w * padded + a)
std::vector<P, AlignedAllocator<P, 64>> table_;

// Multiply and saturate in product type range
static P compute_prod(int64_t signed_w, int64_t act_val) noexcept {
int64_t prod = signed_w * act_val;
if (prod > std::numeric_limits<P>::max()) prod = std::numeric_limits<P>::max();
else if (prod < std::numeric_limits<P>::min()) prod = std::numeric_limits<P>::min();
return static_cast<P>(prod);
}

// Core fill logic: computes table entries by combining signed weight and activation
template<typename GetAct>
void fill_impl(GetAct get_act) noexcept {
for (std::size_t w = 0; w < weight_levels_; ++w) {
int64_t signed_w = (w < weight_levels_ / 2)
? static_cast<int64_t>(w)
: static_cast<int64_t>(w) - static_cast<int64_t>(weight_levels_);
P* row_ptr = &table_[w * padded_a_range_];
for (std::size_t a = 0; a < a_range_; ++a) {
int64_t act = get_act(a);
row_ptr[a] = compute_prod(signed_w, act);
}
// padding intentionally left uninitialized for performance
}
}
};
5 changes: 0 additions & 5 deletions src/matrix_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,13 @@ auto matmul(const MA& A, const MB& B, size_t num_threads = 4)
Matrix<T, RowMajor, PlainStorage<T>> C(M, N);
std::vector<std::thread> threads;

// 計算每個執行緒處理的行數
size_t rows_per_thread = (M + num_threads - 1) / num_threads;

// 為每個執行緒分配工作
for (size_t t = 0; t < num_threads; ++t) {
threads.emplace_back([&, t]() {
size_t start_row = t * rows_per_thread;
size_t end_row = std::min(start_row + rows_per_thread, M);

// 為每個執行緒創建局部結果矩陣
Matrix<T, RowMajor, PlainStorage<T>> local_C(M, N);

for (size_t i = start_row; i < end_row; ++i) {
Expand All @@ -67,7 +64,6 @@ auto matmul(const MA& A, const MB& B, size_t num_threads = 4)
}
}

// 合併結果
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
for (size_t i = start_row; i < end_row; ++i) {
Expand All @@ -78,7 +74,6 @@ auto matmul(const MA& A, const MB& B, size_t num_threads = 4)
});
}

// 等待所有執行緒完成
for (auto& thread : threads) {
thread.join();
}
Expand Down
3 changes: 1 addition & 2 deletions src/post_processing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
#include <cmath>
#include "matrix.hpp"

/// 後處理可選激活函式

/// Activation functions selectable during post-processing.
/// Enumerator values are written out explicitly so the integer mapping
/// seen by callers (e.g. language bindings) is documented and stable.
enum class Activation {
    Linear  = 0,  ///< identity — leave values unchanged
    ReLU    = 1,  ///< rectified linear unit
    Sigmoid = 2,  ///< logistic sigmoid
    Tanh    = 3   ///< hyperbolic tangent
};

/// 1) bias 加法:對於每一列,將 bias[j] 加到 M(i,j)
template<typename T, typename Layout, typename Storage>
Matrix<T,Layout,Storage> add_bias(
const Matrix<T,Layout,Storage>& M,
Expand Down
12 changes: 5 additions & 7 deletions src/quant_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,24 @@
#include <algorithm>
#include <cmath>
#include <cstdint>

// INT4 量化:fp16_val → uint8_t (低 4 bits)
// Quantize a float value to an unsigned INT4 code (stored in the low 4 bits
// of a uint8_t) using an affine mapping: q = round(x / scale) + zero_point.
//
// @param fp16_val   value to quantize
// @param scale      quantization step size (must be positive and non-zero)
// @param zero_point offset mapping the signed range onto [0, 15]; the
//                   default 8 centers the signed int4 range [-8, 7]
// @return quantized code in [0, 15]
inline uint8_t quantize_int4(float fp16_val, float scale, int zero_point = 8) {
    // Clamp to the representable real-value range implied by zero_point:
    // codes span [0, 15], so values span [(0 - zp) * scale, (15 - zp) * scale].
    // (Previously hard-coded to [-8, 7] * scale, which was only correct for
    // the default zero_point of 8.)
    float min_val = static_cast<float>(0 - zero_point) * scale;
    float max_val = static_cast<float>(15 - zero_point) * scale;
    float clamped_val = std::clamp(fp16_val, min_val, max_val);

    // Affine quantization: rescale, round to nearest, shift by zero_point.
    float scaled_val = clamped_val / scale;
    int q = static_cast<int>(std::round(scaled_val)) + zero_point;

    // Defensive clamp: rounding exactly at a boundary can still land one
    // step outside [0, 15].
    q = std::clamp(q, 0, 15);
    return static_cast<uint8_t>(q);
}

// INT4 反量化:uint8_t (低 4 bits) → float

// Dequantize an unsigned INT4 code (low 4 bits of a uint8_t) back to float:
// x = (q - zero_point) * scale. Inverse of quantize_int4 up to rounding.
inline float dequantize_int4(uint8_t q, float scale, int zero_point = 8) {
    // Shift the unsigned code back into the signed range, then rescale.
    const int centered = static_cast<int>(q) - zero_point;
    return scale * static_cast<float>(centered);
}
Loading