
Commit 0379891

WOQ INT8: pad N if it is not a multiple of block_n (#3370)
* WOQ INT8: pad N if it is not a multiple of block_n
* Update UT
* Fix unpack
* Fix unpack issue with padded_N
1 parent 936d11b commit 0379891
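In effect, when the output-channel dimension N of a weight-only-quantized INT8 linear weight is not a multiple of the packing block size block_n, the weight is now zero-padded up to the next multiple before packing, and the padding is hidden again when the weight or bias is handed back to the user. A quick sketch of the size arithmetic, with an illustrative block size rather than one taken from the kernels:

def round_up_to_block(n: int, block_n: int) -> int:
    """Smallest multiple of block_n that is >= n (no-op when already aligned)."""
    return n if n % block_n == 0 else n + (block_n - n % block_n)

# The new unit test uses N=500 and expects a padded size of 512,
# consistent with e.g. block_n = 32 (illustrative value, not from the repo).
assert round_up_to_block(500, 32) == 512
assert round_up_to_block(4096, 32) == 4096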

File tree

5 files changed: +79 -20 lines changed


csrc/cpu/aten/Linear.cpp

Lines changed: 4 additions & 1 deletion
@@ -404,7 +404,10 @@ at::Tensor woq_linear_pack_weight(
         kCPU, weight_int4, weight_dtype, block_n, block_k, lowp_mode);
   }
   if (N % block_n) {
-    return weight;
+    at::Tensor weight_padded =
+        at::pad(weight, {0, 0, 0, block_n - N % block_n}, "constant", 0);
+    return woq_tpp_gemm_packB_stub(
+        kCPU, weight_padded, weight_dtype, block_n, block_k, lowp_mode);
   } else {
     return woq_tpp_gemm_packB_stub(
         kCPU, weight, weight_dtype, block_n, block_k, lowp_mode);
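The pad spec {0, 0, 0, block_n - N % block_n} follows the same last-dimension-first convention as torch.nn.functional.pad, so for a 2-D [N, K] weight it leaves K untouched and appends zero rows until N is a multiple of block_n. A minimal Python sketch of the same operation; the sizes and the block_n value here are illustrative, not taken from the kernels:

import torch
import torch.nn.functional as F

# Illustrative sizes: N is not a multiple of the assumed block size.
N, K, block_n = 500, 512, 32
weight = torch.randn(N, K)

# Pad spec is last-dim-first: (K_left, K_right, N_top, N_bottom).
pad_rows = (block_n - N % block_n) % block_n
padded = F.pad(weight, (0, 0, 0, pad_rows), mode="constant", value=0)

assert padded.shape == (512, 512)             # 500 rounded up to a multiple of 32
assert torch.equal(padded[:N], weight)        # original rows are preserved
assert torch.count_nonzero(padded[N:]) == 0   # appended rows are zeros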

csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp

Lines changed: 9 additions & 17 deletions
@@ -725,24 +725,16 @@ at::Tensor unpack(ContextLinearWoq& context, const at::Tensor& tensor) {
     unpacked_weight = woq_shuffle_weight_back_by_group_idx(
         unpacked_weight, context.weight_shape_, g_idx, group_size);
   }
-  if (tensor.dim() > 2) {
-    auto scales = context.scales_list_[0];
-    auto zero_points = context.zero_points_list_[0];
-    if (context.is_4bit_) {
-      auto unpacked_shape = unpacked_weight.sizes().vec(); // = N * K/2
-      auto shape = context.weight_shape_;
-      shape.back() /= 2;
-      at::Tensor qweight =
-          at::empty(shape, device(c10::kCPU).dtype(c10::kByte));
-      assert(qweight.numel() % 2 == 0);
-      std::memcpy(
-          qweight.data_ptr(), unpacked_weight.data_ptr(), qweight.numel());
-      return qweight;
-    } else { // int8
-      return unpacked_weight;
-    }
+  auto shape = context.weight_shape_;
+  if (context.is_4bit_) {
+    shape.back() = (shape.back() + 1) / 2;
   }
-  return unpacked_weight;
+  // weight may be padded. Copy data according to original shape
+  at::Tensor qweight =
+      at::empty(shape, device(c10::kCPU).dtype(unpacked_weight.scalar_type()));
+  assert(qweight.numel() % 2 == 0);
+  std::memcpy(qweight.data_ptr(), unpacked_weight.data_ptr(), qweight.numel());
+  return qweight;
 }
 
 template <typename T, typename Tg, bool is_4bit = false>
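The rewritten unpack path allocates qweight with the original weight_shape_ (halved in the last dimension for 4-bit weights) and copies only that many elements out of the possibly padded buffer. Because the buffer is row-major and the padding rows sit at the end, a flat copy of the original element count is equivalent to dropping the padded tail. A small sketch with made-up shapes, using the INT8 case:

import torch

# Illustrative shapes: the unpacked buffer was padded to 512 rows,
# but the original weight had only 500 rows (context.weight_shape_).
orig_shape = (500, 512)
padded_unpacked = torch.randint(-128, 127, (512, 512), dtype=torch.int8)

# Row-major layout means the first 500*512 elements are exactly the
# original rows, so a flat copy of that many elements is the same as
# slicing off the padded tail.
numel = orig_shape[0] * orig_shape[1]
qweight = padded_unpacked.flatten()[:numel].reshape(orig_shape)

assert torch.equal(qweight, padded_unpacked[:500])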

csrc/cpu/jit/cpu/kernels/OpContext.cpp

Lines changed: 6 additions & 0 deletions
@@ -427,6 +427,12 @@ at::Tensor IpexWoqLinearOpContext::get_at_packed_weight() {
 }
 
 c10::optional<at::Tensor> IpexWoqLinearOpContext::get_at_bias() {
+  if (op_context_.at_bias_.has_value()) {
+    auto b = op_context_.at_bias_.value();
+    if (b.size(0) > op_context_.weight_shape_[0]) {
+      return c10::make_optional(b.narrow(0, 0, op_context_.weight_shape_[0]));
+    }
+  }
   return op_context_.at_bias_;
 }
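get_at_bias() now trims a bias that was padded along with the weight back to the original output-channel count weight_shape_[0], so callers such as to_public() (see the OpContext.h change below) see a bias of the un-padded size. A small sketch of the narrow call with made-up sizes:

import torch

# Illustrative values: the stored bias was padded to 512 entries,
# but weight_shape_[0] says the layer really has 500 output channels.
orig_out_features = 500
padded_bias = torch.randn(512)

# narrow(dim, start, length) returns a view of the first 500 entries,
# mirroring b.narrow(0, 0, weight_shape_[0]) in the diff.
bias = padded_bias.narrow(0, 0, orig_out_features)

assert bias.shape == (500,)
assert torch.equal(bias, padded_bias[:500])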

csrc/cpu/jit/cpu/kernels/OpContext.h

Lines changed: 1 addition & 1 deletion
@@ -383,7 +383,7 @@ class WoqLinearOpContext : public torch::jit::CustomClassHolder {
     auto orig_weight_ = this->to_public(this->get_at_packed_weight());
     auto weight_dtype_ = this->get_context().weight_dtype_;
     auto weight_shape_ = this->get_weight_shape();
-    auto orig_bias_ = this->get_context().at_bias_;
+    auto orig_bias_ = this->get_at_bias();
     auto scales = this->get_scales();
     auto zero_points = this->get_zero_points();
     auto g_idx = this->get_g_idx();

tests/cpu/test_quantization_default_recipe.py

Lines changed: 59 additions & 1 deletion
@@ -989,7 +989,7 @@ def test(feature, has_bias, w_dtype):
         shape_list = [
             [3, 31, 31],
             [4, 4096, 4096],
-            [9, 4095, 4095],
+            [4, 4096, 4080],
             [196, 4095, 16383],
             [1024, 512, 512],
         ]
@@ -2260,6 +2260,64 @@ def test(feature, has_bias, w_dtype, lowp_mode, enable_amp):
         for shape, use_bias, w_dtype, lowp_mode, enable_amp in cases:
             test(shape, use_bias, w_dtype, lowp_mode, enable_amp)
 
+    def test_weight_padding(self):
+        """
+        If N of weight shape N * K is not a multiple of block_n, it is padded to be a multiple of block_n.
+        """
+
+        class Mod(nn.Module):
+            def __init__(self, input_channel, output_channel, has_bias):
+                super(Mod, self).__init__()
+                self.linear = torch.nn.Linear(input_channel, output_channel, has_bias)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        def test(M, has_bias, w_dtype):
+            N, K, N_padded = 500, 512, 512
+            model = Mod(K, N, has_bias)
+            m = model.eval()
+            m_ref = Mod(K, N_padded, False).eval()
+            data = torch.rand(M, K)
+            weight = model.linear.weight
+            weight_int4, w_scales, w_zero_points = quantize_per_channel(
+                weight,
+                w_dtype,
+                sym_quant=True if w_dtype == WoqWeightDtype.NF4 else False,
+            )
+            weight_fp32 = dequantize_per_channel(
+                weight_int4, w_scales, w_zero_points, w_dtype, weight.shape
+            )
+            if has_bias:
+                bias = model.linear.bias
+                output1 = torch.matmul(data, weight_fp32.T) + bias
+            else:
+                output1 = torch.matmul(data, weight_fp32.T)
+
+            qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
+                weight_dtype=w_dtype
+            )
+            prepared_model = prepare(m, qconfig, example_inputs=data, inplace=False)
+            prepared_model_ref = prepare(
+                m_ref, qconfig, example_inputs=data, inplace=False
+            )
+            with torch.no_grad():
+                woq_model = convert(prepared_model)
+                woq_model_ref = convert(prepared_model_ref)
+                assert (
+                    woq_model.linear.weight.shape == woq_model_ref.linear.weight.shape
+                )
+
+                output2 = woq_model(data)
+                torch.testing.assert_close(output1, output2)
+
+        M_list = [4, 1024]
+        use_bias_list = [True, False]
+        w_dtype_list = [WoqWeightDtype.INT8, WoqWeightDtype.INT4, WoqWeightDtype.NF4]
+        cases = itertools.product(M_list, use_bias_list, w_dtype_list)
+        for M, use_bias, w_dtype in cases:
+            test(M, use_bias, w_dtype)
+
 
 if __name__ == "__main__":
     test = unittest.main()
