Commit 510dc0a

Add an op to dequantize NF4 tensors (#3371)
1 parent 0379891 commit 510dc0a

4 files changed, +186 −0 lines changed

csrc/cpu/aten/Linear.cpp

Lines changed: 13 additions & 0 deletions
@@ -766,6 +766,15 @@ at::Tensor woq_linear_mul_forward(
            op_context.data_ptr<int64_t>()[0])
      ->run_binary(input, "mul", others);
 }
+
+IPEX_DEFINE_DISPATCH(dequant_nf4_stub);
+at::Tensor dequantize_nf4(
+    const at::Tensor& t,
+    const at::Tensor& scales,
+    int64_t group_size,
+    c10::ScalarType out_dtype) {
+  return dequant_nf4_stub(kCPU, t, scales, group_size, out_dtype);
+}
 #endif
 
 } // namespace cpu
@@ -1068,6 +1077,10 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       "woq_linear",
       c10::DispatchKey::AutocastCPU,
       torch_ipex::autocast::woq_linear_forward_v2);
+  m.def(
+      "dequantize_nf4(Tensor t, Tensor scales, int group_size, ScalarType out_dtype) -> Tensor");
+  m.impl(
+      "dequantize_nf4", c10::DispatchKey::CPU, torch_ipex::cpu::dequantize_nf4);
 #endif
   // fuse eltwise
   m.def(
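
Once registered, the op is reachable from Python through torch.ops. The snippet below is an illustrative sketch, not part of the commit: the shapes (two NF4 codes per uint8 byte, so the packed weight is [N, K/2]; per-group scales are [N, K/group_size]) follow the kernel and the test added in this commit, and random bytes are used only to show shapes and dtypes.

import torch
import intel_extension_for_pytorch  # noqa: F401  -- loads the torch_ipex ops

N, K, group_size = 1024, 1024, 32
packed = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)  # two NF4 codes per byte
scales = torch.rand(N, K // group_size, dtype=torch.float)      # per-group scales
out = torch.ops.torch_ipex.dequantize_nf4(packed, scales, group_size, torch.bfloat16)
assert out.shape == (N, K) and out.dtype == torch.bfloat16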

csrc/cpu/aten/Linear.h

Lines changed: 7 additions & 0 deletions
@@ -247,12 +247,19 @@ using woq_dequant_int4_to_int8_packed_fn = at::Tensor (*)(
     int64_t,
     at::Tensor&);
 
+using dequant_nf4_fn = at::Tensor (*)(
+    const at::Tensor&,
+    const at::Tensor&,
+    int64_t,
+    c10::ScalarType);
+
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_kernel_fn, woq_tpp_gemm_kernel_stub);
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_packB_fn, woq_tpp_gemm_packB_stub);
 IPEX_DECLARE_DISPATCH(woq_tpp_gemm_unpackB_fn, woq_tpp_gemm_unpackB_stub);
 IPEX_DECLARE_DISPATCH(
     woq_dequant_int4_to_int8_packed_fn,
     woq_dequant_int4_to_int8_packed_stub);
+IPEX_DECLARE_DISPATCH(dequant_nf4_fn, dequant_nf4_stub);
 
 // Fusion types
 #define WOQ_FUSE_NONE 0x0

csrc/cpu/aten/kernels/WoqTppKrnl.cpp

Lines changed: 131 additions & 0 deletions
@@ -8,6 +8,7 @@
 #include "aten/utils/woq.h"
 #include "csrc/cpu/tpp/kernels/TPPGEMMKrnl.h"
 #include "csrc/cpu/tpp/woq/tla.h"
+#include "csrc/cpu/vec/vec.h"
 
 #ifdef __GNUC__
 #include <features.h>
@@ -4981,6 +4982,122 @@ at::Tensor dequantize_int4_weight_to_int8_packed(
   return dqw;
 }
 
+at::Tensor dequantize_nf4(
+    const at::Tensor& t,
+    const at::Tensor& scales,
+    int64_t group_size,
+    c10::ScalarType out_dtype) {
+  TORCH_CHECK(
+      t.dim() == 2,
+      "dequantize_nf4 only supports 2D input, but got ",
+      t.dim(),
+      "D");
+  TORCH_CHECK(
+      t.scalar_type() == c10::kByte || t.scalar_type() == c10::kChar,
+      "dequantize_nf4 only supports uint8 or int8 input, but got ",
+      t.scalar_type());
+  TORCH_CHECK(
+      scales.dim() <= 2,
+      "dequantize_nf4: scales must be 1D (per-channel) or 2D (per-group), but got ",
+      scales.dim(),
+      "D");
+  TORCH_CHECK(
+      out_dtype == c10::kFloat || out_dtype == c10::kBFloat16 ||
+          out_dtype == c10::kHalf,
+      "dequantize_nf4 only supports float, bfloat16 or float16 output, but got ",
+      out_dtype);
+  auto N = t.size(0);
+  auto K = t.size(1) * 2;
+  using Tcomp = float;
+  constexpr auto VEC_LEN = sizeof(__m512i) / sizeof(Tcomp);
+  if (K % VEC_LEN == 0 && (group_size >= VEC_LEN || group_size < 0)) {
+    auto n_groups = K / group_size;
+    torch::Tensor out = torch::empty({N, K}, out_dtype);
+    product_dispatcher<
+        std::tuple<c10::ScalarType, c10::ScalarType>,
+        std::tuple<
+            enumerate_dispatcher<
+                c10::ScalarType,
+                c10::kFloat,
+                c10::kBFloat16,
+                c10::kHalf>,
+            enumerate_dispatcher<
+                c10::ScalarType,
+                c10::kFloat,
+                c10::kBFloat16,
+                c10::kHalf>>>::
+        call(
+            std::make_tuple(out_dtype, scales.scalar_type()),
+            [&](auto tuple) {
+              // Note: we don't use the `Dequantize` utilities because they are
+              // designed for packed layout but we are handling plain layout
+              // here.
+              auto out_dtype_ = std::get<0>(tuple);
+              auto scales_dtype_ = std::get<1>(tuple);
+              using T =
+                  typename c10::impl::ScalarTypeToCPPType<out_dtype_>::type;
+              using Tscale =
+                  typename c10::impl::ScalarTypeToCPPType<scales_dtype_>::type;
+              using VT = typename VecType<Tcomp>::type;
+              using V = VecOps<VT>;
+              VT lut = V::set_nf4_lut();
+              constexpr auto k_step = VEC_LEN / 2;
+              auto pt = GetVLAPtr<uint8_t>(
+                  (uint8_t*)t.data_ptr(), {t.size(1) / k_step, k_step});
+              auto pscales = group_size < 0
+                  ? GetVLAPtr<Tscale>(scales, {1})
+                  : GetVLAPtr<Tscale>(scales, {n_groups});
+              auto out_ptr = GetVLAPtr<T>(out, {K / VEC_LEN, VEC_LEN});
+
+              auto dequant_loop = ThreadedLoop<2>({{N}}, /* loop_scheme */ "A");
+              dequant_loop(
+                  [&](int* idx) {
+                    int n = idx[0];
+                    for (int k = 0; k < t.size(1); k += k_step) {
+                      // Load 64 bits of nf4 data and a single scale data
+                      auto p = pt[n][k / k_step];
+                      auto scale_idx = group_size < 0 ? 0 : k * 2 / group_size;
+                      auto vscales = V::set1((float)pscales[n][scale_idx]);
+                      uint64_t packed = reinterpret_cast<uint64_t*>(p)[0];
+                      // unpack nf4 data to 32-bit integers
+                      uint64_t high = 0;
+                      uint64_t low = 0;
+                      for (int i = 0; i < 8; ++i) {
+                        low |= ((packed >> (i * 4)) & 0xf) << (i * 8);
+                        high |= ((packed >> (i * 4 + 32)) & 0xf) << (i * 8);
+                      }
+                      __m128i packed_128 = _mm_set_epi64x(high, low);
+                      __m512i vint32 = _mm512_cvtepu8_epi32(packed_128);
+                      // Table look-up
+                      __m512 vout = _mm512_permutexvar_ps(vint32, lut);
+                      // Apply scale
+                      vout = V::mul(vout, vscales);
+                      // Store results
+                      auto pout = out_ptr[n][k / k_step];
+                      if constexpr (std::is_same<T, float>()) {
+                        _mm512_storeu_ps(pout, vout);
+                      } else if constexpr (std::is_same<T, at::BFloat16>()) {
+                        _mm256_storeu_si256(
+                            (__m256i*)pout, cvt_fp32_to_bf16(vout));
+                      } else if constexpr (std::is_same<T, at::Half>()) {
+                        _mm256_storeu_si256(
+                            (__m256i*)pout, cvt_fp32_to_fp16(vout));
+                      } else {
+                        TORCH_CHECK(false, "Unexpected dtype");
+                      }
+                    }
+                  },
+                  [&]() {},
+                  [&]() {});
+            },
+            [&](auto out_dtype_) { failing_fallback(); });
+    return out;
+  }
+  auto out = dequantize_woq_weight(
+      t, {N, K}, scales.unsqueeze(-1), at::Tensor(), WOQ_DTYPE_NF4, group_size);
+  return out.to(out_dtype);
+}
+
 #else // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)
 
 #define SMALL_BATCH_THRESHOLD 32
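
As a quick aside, the 64-bit shift/mask loop in the kernel above spreads sixteen 4-bit codes into sixteen bytes (the low eight nibbles into `low`, the high eight into `high`) before `_mm512_cvtepu8_epi32` widens them to 32-bit lanes for the LUT permute. A throwaway scalar check of that trick, in plain Python (illustrative only):

packed = 0x0123456789ABCDEF            # sixteen 4-bit codes packed into 64 bits
low = high = 0
for i in range(8):
    low |= ((packed >> (i * 4)) & 0xF) << (i * 8)
    high |= ((packed >> (i * 4 + 32)) & 0xF) << (i * 8)
assert low == 0x08090A0B0C0D0E0F       # nibbles 0..7, one per output byte
assert high == 0x0001020304050607      # nibbles 8..15, one per output byte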
@@ -5087,6 +5204,19 @@ at::Tensor dequantize_int4_weight_to_int8_packed(
     at::Tensor& compensation) {
   return qw_packed;
 }
+
+at::Tensor dequantize_nf4(
+    const at::Tensor& t,
+    const at::Tensor& scales,
+    int64_t group_size,
+    c10::ScalarType out_dtype) {
+  auto N = t.size(0);
+  auto K = t.size(1) * 2;
+  auto out = dequantize_woq_weight(
+      t, {N, K}, scales.unsqueeze(-1), at::Tensor(), WOQ_DTYPE_NF4, group_size);
+  return out.to(out_dtype);
+}
+
 #endif // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)
 
 } // namespace
@@ -5097,6 +5227,7 @@ IPEX_REGISTER_DISPATCH(woq_tpp_gemm_unpackB_stub, &qlinear_woq_unpack);
 IPEX_REGISTER_DISPATCH(
     woq_dequant_int4_to_int8_packed_stub,
     &dequantize_int4_weight_to_int8_packed);
+IPEX_REGISTER_DISPATCH(dequant_nf4_stub, &dequantize_nf4);
 
 } // namespace cpu
 } // namespace torch_ipex
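
For readers who prefer plain PyTorch to AVX-512 intrinsics, here is a minimal reference sketch of what dequantize_nf4 computes: unpack two 4-bit codes per byte, map each code through the 16-entry NF4 look-up table, then multiply by the per-channel or per-group scale. It is illustrative only: the helper name nf4_dequant_ref is made up, the LUT values are the commonly published QLoRA/bitsandbytes code book and are assumed (not verified here) to match V::set_nf4_lut(), and the low-nibble-first ordering is inferred from the kernel's unpack loop above.

import torch

# Assumed NF4 code book (QLoRA/bitsandbytes values as commonly published).
NF4_LUT = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

def nf4_dequant_ref(packed, scales, group_size, out_dtype):
    # packed: uint8 [N, K/2], two NF4 codes per byte (low nibble = even column,
    # matching the nibble order in the kernel above).
    N, half_k = packed.shape
    K = half_k * 2
    codes = torch.empty(N, K, dtype=torch.long)
    codes[:, 0::2] = (packed & 0xF).long()   # low nibbles
    codes[:, 1::2] = (packed >> 4).long()    # high nibbles
    vals = NF4_LUT[codes]                    # 16-entry table look-up, fp32
    if group_size < 0:                       # per-channel: scales is [N] (or [N, 1])
        return (vals * scales.float().reshape(N, 1)).to(out_dtype)
    # per-group: scales is [N, K // group_size]
    s = scales.float().repeat_interleave(group_size, dim=1)
    return (vals * s).to(out_dtype)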

tests/cpu/test_quantization_default_recipe.py

Lines changed: 35 additions & 0 deletions
@@ -2319,6 +2319,41 @@ def test(M, has_bias, w_dtype):
                 test(M, use_bias, w_dtype)
 
 
+class QuantizedOpTester(TestCase):
+    def test_dequantize_nf4(self):
+        dtype_list = [torch.float, torch.bfloat16, torch.half]
+        group_size_list = [-1, 32, 128]
+        cases = itertools.product(dtype_list, group_size_list)
+        for dtype, group_size in cases:
+            t_fp = torch.randn(1024, 1024, dtype=dtype)
+            scale_dtype_list = list(set([dtype, torch.float]))
+            for scale_dtype in scale_dtype_list:
+                if group_size < 0:
+                    t, scales, zp = quantize_per_channel(
+                        t_fp, WoqWeightDtype.NF4, None, None, sym_quant=True
+                    )
+                    scales = scales.to(scale_dtype)
+                    out_ref = dequantize_per_channel(
+                        t, scales, zp, WoqWeightDtype.NF4, t_fp.shape
+                    ).to(dtype)
+                    out = torch.ops.torch_ipex.dequantize_nf4(
+                        t, scales, group_size, dtype
+                    )
+                    assert torch.allclose(out, out_ref)
+                else:
+                    t, scales, zp = quantize_per_block(
+                        t_fp, WoqWeightDtype.NF4, group_size, None, None, sym_quant=True
+                    )
+                    scales = scales.to(scale_dtype)
+                    out_ref = dequantize_per_block(
+                        t, scales, zp, WoqWeightDtype.NF4, group_size, t_fp.shape
+                    ).to(dtype)
+                    out = torch.ops.torch_ipex.dequantize_nf4(
+                        t, scales, group_size, dtype
+                    )
+                    assert torch.allclose(out, out_ref)
+
+
 if __name__ == "__main__":
     test = unittest.main()
     run_tests()
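
The new case can be run on its own with something like the command below; pytest-style filtering is an assumption about your workflow, while the file path and test name come from this diff.

python -m pytest tests/cpu/test_quantization_default_recipe.py -k test_dequantize_nf4 -v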
