From 3f5abe73e303d5c0c0b07261eafbb44138e873bd Mon Sep 17 00:00:00 2001 From: Wayne Date: Sun, 4 May 2025 14:21:47 +0800 Subject: [PATCH 1/2] remove redundant comment --- src/gemm_engine.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gemm_engine.hpp b/src/gemm_engine.hpp index 1109c73..93dbef2 100644 --- a/src/gemm_engine.hpp +++ b/src/gemm_engine.hpp @@ -1,4 +1,3 @@ -// src/gemm_engine.hpp #pragma once #include #include From 664fd1e320c363884e94673515c7cf94a807e22e Mon Sep 17 00:00:00 2001 From: Wayne Date: Sun, 4 May 2025 14:51:10 +0800 Subject: [PATCH 2/2] feature:add accuracy and error analysis component --- Makefile | 2 +- README.md | 2 +- scripts/example.py | 49 +++++++++++---------- src/accuracy_utils.hpp | 24 ++++++++++ src/bindings.cpp | 23 ++++++++++ tests/{test_post_process.py => test_api.py} | 7 +++ tests/test_correctness.cpp | 19 +++++++- 7 files changed, 100 insertions(+), 26 deletions(-) create mode 100644 src/accuracy_utils.hpp rename tests/{test_post_process.py => test_api.py} (70%) diff --git a/Makefile b/Makefile index a36042f..27e6671 100644 --- a/Makefile +++ b/Makefile @@ -57,7 +57,7 @@ mpgemm$(PYEXT): src/bindings.cpp $(HEADERS) # run pytest pytest: all - PYTHONPATH=. python3 -m pytest -q tests/test_post_process.py + PYTHONPATH=. python3 -m pytest -q tests/test_api.py run: all ./$(TARGET_MAIN) diff --git a/README.md b/README.md index d865e74..851f539 100644 --- a/README.md +++ b/README.md @@ -112,7 +112,7 @@ mpGEMM/ │ └── bindings.cpp ├── tests/ │ ├── test_correctness.cpp -│ ├── test_post_process.py +│ ├── test_api.py │ └── run_benchmark.cpp ├── scripts/ │ └── benchmark.py diff --git a/scripts/example.py b/scripts/example.py index 437be5f..61f70e2 100644 --- a/scripts/example.py +++ b/scripts/example.py @@ -2,52 +2,55 @@ import os import sys -# Ensure project root is on PYTHONPATH +# 确保项目根目录在 PYTHONPATH sys.path.insert(0, os.path.abspath(os.path.join(__file__, os.pardir, '..'))) import numpy as np import mpgemm from mpgemm import Activation - def main(): - # Matrix dimensions + # 矩阵维度 M, K, N = 128, 128, 128 - # Random number generator - rng = np.random.default_rng(42) + # 随机数生成器 + rng = np.random.default_rng(2025) - # Generate random quantized weights (INT4 range 0-15) + # 生成随机 INT4 权重(0-15) weights = rng.integers(0, 16, size=(M, K), dtype=np.uint8) - # Generate random FP16 activations + # 生成随机 FP16 激活 activations = rng.standard_normal(size=(K, N)).astype(np.float16) - # Generate random bias (FP32) + # 随机 bias(FP32) bias = rng.standard_normal(size=N).astype(np.float32) - # Initialize GEMM engine (LUT backend) - gemm = mpgemm.Engine("lut") - gemm.generate_lut(bit_width=4) - - # Flatten inputs to Python lists + # 扁平化并转为 Python 列表 w_flat = weights.flatten().tolist() - # cast activations to Python float list a_flat = activations.flatten().astype(float).tolist() bias_list = bias.tolist() - # Perform matrix multiplication - out_flat = gemm.matmul(w_flat, a_flat, M, K, N) + # === 1. 基准参考输出 === + gemm_ref = mpgemm.Engine("naive") + ref_flat = gemm_ref.matmul(w_flat, a_flat, M, K, N) - # Post-processing - out_biased = gemm.add_bias(out_flat, M, N, bias_list) - out_relu = gemm.apply_activation(out_biased, M, N, Activation.ReLU) + # === 2. LUT 后端输出 === + gemm_lut = mpgemm.Engine("lut") + gemm_lut.generate_lut(bit_width=4) + out_flat = gemm_lut.matmul(w_flat, a_flat, M, K, N) - # Reshape back to matrix - output = np.array(out_relu, dtype=np.float32).reshape(M, N) + # === 3. 后处理示例 === + out_biased = gemm_lut.add_bias(out_flat, M, N, bias_list) + out_relu = gemm_lut.apply_activation(out_biased, M, N, Activation.ReLU) - # Display results + # 还原成矩阵 + output = np.array(out_relu, dtype=np.float32).reshape(M, N) print(f"Output shape: {output.shape}") - print("Sample output [0,:5]:", output[0, :5]) + print("Sample row[0, :5]:", output[0, :5]) + # === 4. 误差分析示例 === + stats = mpgemm.measure_error(ref_flat, out_flat) + print(f"\nError relative to naive:") + print(f" MSE = {stats['mse']:.6f}") + print(f" Max error = {stats['max_error']:.6f}") if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/accuracy_utils.hpp b/src/accuracy_utils.hpp new file mode 100644 index 0000000..a918d08 --- /dev/null +++ b/src/accuracy_utils.hpp @@ -0,0 +1,24 @@ +#pragma once +#include +#include +#include +#include + +struct ErrorStats { + double mse; + double max_error; +}; + +// Computes MSE and max absolute error between two same-sized flat arrays. +inline ErrorStats measure_error(const std::vector& ref, + const std::vector& test) { + size_t N = ref.size(); + double sum_sq = 0.0; + double max_err = 0.0; + for (size_t i = 0; i < N; ++i) { + double diff = double(test[i]) - double(ref[i]); + sum_sq += diff * diff; + max_err = std::max(max_err, std::abs(diff)); + } + return { sum_sq / N, max_err }; +} diff --git a/src/bindings.cpp b/src/bindings.cpp index e625078..af578dc 100644 --- a/src/bindings.cpp +++ b/src/bindings.cpp @@ -5,6 +5,7 @@ #include "lut_utils.hpp" #include "post_processing.hpp" #include "gemm_engine.hpp" +#include "accuracy_utils.hpp" namespace py = pybind11; @@ -71,4 +72,26 @@ PYBIND11_MODULE(mpgemm, m) { .def("apply_activation", &Engine::apply_activation, "Apply activation to GEMM output", py::arg("C"), py::arg("M"), py::arg("N"), py::arg("act")); + + // --- Error measurement --- + py::class_(m, "ErrorStats") + .def_readonly("mse", &ErrorStats::mse) + .def_readonly("max_error", &ErrorStats::max_error); + m.def("measure_error", + [](const std::vector& ref, + const std::vector& test) { + auto s = measure_error(ref, test); + py::dict d; + d["mse"] = s.mse; + d["max_error"] = s.max_error; + return d; + }, + py::arg("reference"), + py::arg("test"), + R"( + Compute error statistics between two flat float lists: + - mse: mean squared error + - max_error: maximum absolute error + Returns a dict: {\"mse\": ..., \"max_error\": ...} + )"); } \ No newline at end of file diff --git a/tests/test_post_process.py b/tests/test_api.py similarity index 70% rename from tests/test_post_process.py rename to tests/test_api.py index 2f23ac6..0ed15bb 100644 --- a/tests/test_post_process.py +++ b/tests/test_api.py @@ -13,3 +13,10 @@ def test_relu(): R = mpgemm.apply_activation(M.flatten().tolist(), 2, 2, mpgemm.Activation.ReLU) R = np.array(R).reshape(2,2) assert np.all(R >= 0) + +def test_measure_error(): + ref = [1.0, 2.0, 3.0] + test = [1.1, 1.9, 2.5] + stats = mpgemm.measure_error(ref, test) + assert abs(stats["mse"] - 0.09) < 1e-6 + assert abs(stats["max_error"] - 0.5) < 1e-6 \ No newline at end of file diff --git a/tests/test_correctness.cpp b/tests/test_correctness.cpp index a8ead0c..233f745 100644 --- a/tests/test_correctness.cpp +++ b/tests/test_correctness.cpp @@ -5,12 +5,14 @@ #include "../src/lut_utils.hpp" #include "../src/quant_utils.hpp" #include "../src/post_processing.hpp" +#include "../src/accuracy_utils.hpp" #include #include #include #include #include +#include // Helper: compare two matrices for equality template @@ -376,10 +378,24 @@ bool run_linear_test() { return pass; } +// 14. accuracy test +bool run_accuracy_test() { + std::cout << "Running accuracy test...\n"; + std::vector A = {1.0f, 2.0f, 3.0f}; + std::vector B = {1.1f, 1.9f, 2.5f}; + auto stats = measure_error(A, B); + // manual: + // diffs = {0.1, -0.1, -0.5}, sq = {0.01,0.01,0.25}, mse = 0.27/3 = 0.09 + assert(std::fabs(stats.mse - 0.09) < 1e-6); + assert(std::fabs(stats.max_error - 0.5) < 1e-6); + std::cout << "Accuracy test PASS\n"; + return true; +} + int main() { int passed=0; - int total=15; + int total=16; if (run_basic_test()) ++passed; if (run_negative_test()) ++passed; if (run_non_square_test()) ++passed; @@ -395,6 +411,7 @@ int main() { if (run_sigmoid_test()) ++passed; if (run_tanh_test()) ++passed; if (run_linear_test()) ++passed; + if (run_accuracy_test()) ++passed; #ifdef USE_MKL ++total; // 只有啟用 MKL 才加總數 if (run_mkl_test()) ++passed;