#include <cstring>

#include "aten/utils/woq.h"
#include "csrc/cpu/tpp/kernels/TPPGEMMKrnl.h"
#include "csrc/cpu/tpp/woq/tla.h"
#include "csrc/cpu/vec/vec.h"
1112
1213#ifdef __GNUC__
1314#include < features.h>
@@ -4981,6 +4982,122 @@ at::Tensor dequantize_int4_weight_to_int8_packed(
49814982 return dqw;
49824983}
49834984
4985+ at::Tensor dequantize_nf4 (
4986+ const at::Tensor& t,
4987+ const at::Tensor& scales,
4988+ int64_t group_size,
4989+ c10::ScalarType out_dtype) {
4990+ TORCH_CHECK (
4991+ t.dim () == 2 ,
4992+ " dequantize_nf4 only supports 2D input, but got " ,
4993+ t.dim (),
4994+ " D" );
4995+ TORCH_CHECK (
4996+ t.scalar_type () == c10::kByte || t.scalar_type () == c10::kChar ,
4997+ " dequantize_nf4 only supports uint8 or int8 input, but got " ,
4998+ t.scalar_type ());
4999+ TORCH_CHECK (
5000+ scales.dim () <= 2 ,
5001+ " dequantize_nf4: scales must be 1D (per-channel) or 2D (per-group), but got " ,
5002+ scales.dim (),
5003+ " D" );
5004+ TORCH_CHECK (
5005+ out_dtype == c10::kFloat || out_dtype == c10::kBFloat16 ||
5006+ out_dtype == c10::kHalf ,
5007+ " dequantize_nf4 only supports float, bfloat16 or float16 output, but got " ,
5008+ out_dtype);
5009+ auto N = t.size (0 );
5010+ auto K = t.size (1 ) * 2 ;
5011+ using Tcomp = float ;
5012+ constexpr auto VEC_LEN = sizeof (__m512i) / sizeof (Tcomp);
5013+ if (K % VEC_LEN == 0 && (group_size >= VEC_LEN || group_size < 0 )) {
5014+ auto n_groups = K / group_size;
5015+ torch::Tensor out = torch::empty ({N, K}, out_dtype);
5016+ product_dispatcher<
5017+ std::tuple<c10::ScalarType, c10::ScalarType>,
5018+ std::tuple<
5019+ enumerate_dispatcher<
5020+ c10::ScalarType,
5021+ c10::kFloat ,
5022+ c10::kBFloat16 ,
5023+ c10::kHalf >,
5024+ enumerate_dispatcher<
5025+ c10::ScalarType,
5026+ c10::kFloat ,
5027+ c10::kBFloat16 ,
5028+ c10::kHalf >>>::
5029+ call (
5030+ std::make_tuple (out_dtype, scales.scalar_type ()),
5031+ [&](auto tuple) {
5032+ // Note: we don't use the `Dequantize` utilities because they are
5033+ // designed for packed layout but we are handling plain layout
5034+ // here.
5035+ auto out_dtype_ = std::get<0 >(tuple);
5036+ auto scales_dtype_ = std::get<1 >(tuple);
5037+ using T =
5038+ typename c10::impl::ScalarTypeToCPPType<out_dtype_>::type;
5039+ using Tscale =
5040+ typename c10::impl::ScalarTypeToCPPType<scales_dtype_>::type;
5041+ using VT = typename VecType<Tcomp>::type;
5042+ using V = VecOps<VT>;
5043+ VT lut = V::set_nf4_lut ();
5044+ constexpr auto k_step = VEC_LEN / 2 ;
5045+ auto pt = GetVLAPtr<uint8_t >(
5046+ (uint8_t *)t.data_ptr (), {t.size (1 ) / k_step, k_step});
5047+ auto pscales = group_size < 0
5048+ ? GetVLAPtr<Tscale>(scales, {1 })
5049+ : GetVLAPtr<Tscale>(scales, {n_groups});
5050+ auto out_ptr = GetVLAPtr<T>(out, {K / VEC_LEN, VEC_LEN});
5051+
5052+ auto dequant_loop = ThreadedLoop<2 >({{N}}, /* loop_scheme */ " A" );
5053+ dequant_loop (
5054+ [&](int * idx) {
5055+ int n = idx[0 ];
5056+ for (int k = 0 ; k < t.size (1 ); k += k_step) {
5057+ // Load 64 bits of nf4 data and a single scale data
5058+ auto p = pt[n][k / k_step];
5059+ auto scale_idx = group_size < 0 ? 0 : k * 2 / group_size;
5060+ auto vscales = V::set1 ((float )pscales[n][scale_idx]);
5061+ uint64_t packed = reinterpret_cast <uint64_t *>(p)[0 ];
5062+ // unpack nf4 data to 32-bit integers
5063+ uint64_t high = 0 ;
5064+ uint64_t low = 0 ;
5065+ for (int i = 0 ; i < 8 ; ++i) {
5066+ low |= ((packed >> (i * 4 )) & 0xf ) << (i * 8 );
5067+ high |= ((packed >> (i * 4 + 32 )) & 0xf ) << (i * 8 );
5068+ }
5069+ __m128i packed_128 = _mm_set_epi64x (high, low);
5070+ __m512i vint32 = _mm512_cvtepu8_epi32 (packed_128);
5071+ // Table look-up
5072+ __m512 vout = _mm512_permutexvar_ps (vint32, lut);
5073+ // Apply scale
5074+ vout = V::mul (vout, vscales);
5075+ // Store results
5076+ auto pout = out_ptr[n][k / k_step];
5077+ if constexpr (std::is_same<T, float >()) {
5078+ _mm512_storeu_ps (pout, vout);
5079+ } else if constexpr (std::is_same<T, at::BFloat16>()) {
5080+ _mm256_storeu_si256 (
5081+ (__m256i*)pout, cvt_fp32_to_bf16 (vout));
5082+ } else if constexpr (std::is_same<T, at::Half>()) {
5083+ _mm256_storeu_si256 (
5084+ (__m256i*)pout, cvt_fp32_to_fp16 (vout));
5085+ } else {
5086+ TORCH_CHECK (false , " Unexpected dtype" );
5087+ }
5088+ }
5089+ },
5090+ [&]() {},
5091+ [&]() {});
5092+ },
5093+ [&](auto out_dtype_) { failing_fallback (); });
5094+ return out;
5095+ }
5096+ auto out = dequantize_woq_weight (
5097+ t, {N, K}, scales.unsqueeze (-1 ), at::Tensor (), WOQ_DTYPE_NF4, group_size);
5098+ return out.to (out_dtype);
5099+ }
5100+
49845101#else // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)
49855102
49865103#define SMALL_BATCH_THRESHOLD 32
@@ -5087,6 +5204,19 @@ at::Tensor dequantize_int4_weight_to_int8_packed(
50875204 at::Tensor& compensation) {
50885205 return qw_packed;
50895206}
5207+
5208+ at::Tensor dequantize_nf4 (
5209+ const at::Tensor& t,
5210+ const at::Tensor& scales,
5211+ int64_t group_size,
5212+ c10::ScalarType out_dtype) {
5213+ auto N = t.size (0 );
5214+ auto K = t.size (1 ) * 2 ;
5215+ auto out = dequantize_woq_weight (
5216+ t, {N, K}, scales.unsqueeze (-1 ), at::Tensor (), WOQ_DTYPE_NF4, group_size);
5217+ return out.to (out_dtype);
5218+ }
5219+
50905220#endif // defined(CPU_CAPABILITY_AVX512_FP16) && defined(COMPILER_PREREQ_MET)
50915221
50925222} // namespace
@@ -5097,6 +5227,7 @@ IPEX_REGISTER_DISPATCH(woq_tpp_gemm_unpackB_stub, &qlinear_woq_unpack);
50975227IPEX_REGISTER_DISPATCH (
50985228 woq_dequant_int4_to_int8_packed_stub,
50995229 &dequantize_int4_weight_to_int8_packed);
5230+ IPEX_REGISTER_DISPATCH (dequant_nf4_stub, &dequantize_nf4);
51005231
51015232} // namespace cpu
51025233} // namespace torch_ipex
0 commit comments