Skip to content

Commit 16bceb5

Browse files
authored
WOQ: Fuse compute_int8_qparams_per_block and quantize_per_block (#3390)
1 parent 2effbbb commit 16bceb5

File tree

6 files changed

+1777
-1195
lines changed

6 files changed

+1777
-1195
lines changed

csrc/cpu/aten/kernels/WoqInt8GemmAPerKBlockKrnl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,9 @@ at::Tensor woq_gemm_int8(
105105
if (quant_block_k <= 0)
106106
quant_block_k = block_k;
107107
bool is_sym_quant = !is_asymmetric_quant_a(quant_a_mode);
108-
auto [scale_a, zp_a] = compute_int8_qparams_per_block<act_type>(
109-
x, quant_block_k, quant_a_mode, is_sym_quant);
110-
auto x_quantized = quantize_per_block<act_type>(
111-
x, scale_a, zp_a, quant_block_k, quant_a_mode, is_sym_quant);
108+
auto [x_quantized, scale_a, zp_a] =
109+
dynamic_quantize_per_block<act_type>(
110+
x, quant_block_k, quant_a_mode);
112111
float* scale_a_ptr = (float*)scale_a.data_ptr();
113112
int32_t* zp_a_ptr =
114113
is_sym_quant ? nullptr : (int32_t*)zp_a.data_ptr();

csrc/cpu/aten/kernels/WoqInt8GemmAPerMKBlockKrnl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,9 @@ at::Tensor woq_gemm_int8(
105105
if (quant_block_k <= 0)
106106
quant_block_k = block_k;
107107
bool is_sym_quant = !is_asymmetric_quant_a(quant_a_mode);
108-
auto [scale_a, zp_a] = compute_int8_qparams_per_block<act_type>(
109-
x, quant_block_k, quant_a_mode, is_sym_quant);
110-
auto x_quantized = quantize_per_block<act_type>(
111-
x, scale_a, zp_a, quant_block_k, quant_a_mode, is_sym_quant);
108+
auto [x_quantized, scale_a, zp_a] =
109+
dynamic_quantize_per_block<act_type>(
110+
x, quant_block_k, quant_a_mode);
112111
float* scale_a_ptr = (float*)scale_a.data_ptr();
113112
int32_t* zp_a_ptr =
114113
is_sym_quant ? nullptr : (int32_t*)zp_a.data_ptr();

csrc/cpu/aten/kernels/WoqInt8GemmAPerMKrnl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,9 @@ at::Tensor woq_gemm_int8(
105105
if (quant_block_k <= 0)
106106
quant_block_k = block_k;
107107
bool is_sym_quant = !is_asymmetric_quant_a(quant_a_mode);
108-
auto [scale_a, zp_a] = compute_int8_qparams_per_block<act_type>(
109-
x, quant_block_k, quant_a_mode, is_sym_quant);
110-
auto x_quantized = quantize_per_block<act_type>(
111-
x, scale_a, zp_a, quant_block_k, quant_a_mode, is_sym_quant);
108+
auto [x_quantized, scale_a, zp_a] =
109+
dynamic_quantize_per_block<act_type>(
110+
x, quant_block_k, quant_a_mode);
112111
float* scale_a_ptr = (float*)scale_a.data_ptr();
113112
int32_t* zp_a_ptr =
114113
is_sym_quant ? nullptr : (int32_t*)zp_a.data_ptr();

0 commit comments

Comments
 (0)