Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 0ebd890

Browse files
committed
update row_major for origin PVC/ARC template
1 parent e7f2716 commit 0ebd890

File tree

12 files changed

+96
-88
lines changed

12 files changed

+96
-88
lines changed

include/common/core/base_consts.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ namespace gpu::xetla {
2525

2626
/// @addtogroup xetla_core_base_consts
2727
/// @{
28-
enum quant_mode : uint8_t { S4_ASYM, S4_FULLRANGE_NO_ZP };
2928
/// @} xetla_core_base_consts
3029

3130
} // namespace gpu::xetla

include/common/core/common_types.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,13 @@ enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
2626
enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
2727

2828
enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
29+
30+
enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
31+
32+
struct quant_info {
33+
quant_mode quant_mode;
34+
uint32_t dequant_s;
35+
mem_layout weight_mem_layout;
36+
};
37+
2938
} // namespace gpu::xetla

include/experimental/group/gemm/compute_policy.hpp

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ template <
3131
typename perf_tuning_knob_,
3232
typename dtype_scale_,
3333
typename dtype_zero_pt_,
34-
quant_mode quant_type_,
35-
uint32_t dequant_s_,
34+
quant_info quant_info_,
3635
mma_engine mma_engine_ = mma_engine::xmx,
3736
gpu_arch arch_tag_ = gpu_arch::XeHpc,
3837
typename enable = void>
@@ -44,17 +43,15 @@ template <
4443
typename perf_tuning_knob_,
4544
typename dtype_scale_,
4645
typename dtype_zero_pt_,
47-
quant_mode quant_type_,
48-
int dequant_s_,
46+
quant_info quant_info_,
4947
mma_engine mma_engine_,
5048
gpu_arch arch_tag_>
5149
struct compute_policy_int4_dequantize<
5250
compute_attr_,
5351
perf_tuning_knob_,
5452
dtype_scale_,
5553
dtype_zero_pt_,
56-
quant_type_,
57-
dequant_s_,
54+
quant_info_,
5855
mma_engine_,
5956
arch_tag_,
6057
std::enable_if_t<mma_engine_ == mma_engine::xmx>> {
@@ -70,17 +67,17 @@ struct compute_policy_int4_dequantize<
7067
static constexpr mma_engine mma_engine = mma_engine_;
7168
static constexpr gpu_arch arch_tag = arch_tag_;
7269

73-
static_assert(arch_has_xmx<arch_tag>(), "XeLpg does not support xmx");
70+
static_assert(arch_has_xmx<arch_tag>, "XeLpg does not support xmx");
7471

7572
static constexpr bool is_int4_matB_policy = true;
7673

77-
static constexpr uint32_t dequant_s = dequant_s_;
74+
static constexpr uint32_t dequant_s = quant_info_.dequant_s;
7875
static_assert(
7976
(dequant_s % (32 / sizeof(dtype_mma_b))) == 0,
8077
"dequant_s should be a multiply of 32B");
8178
using dtype_scale = dtype_scale_;
8279
using dtype_zero_pt = dtype_zero_pt_;
83-
static constexpr quant_mode quant_type = quant_type_;
80+
static constexpr quant_mode quant_mode = quant_info_.quant_mode;
8481

8582
static constexpr uint32_t block_size_y_a = 16;
8683
using mma_attr = mma_attr_t<arch_tag_, block_size_y_a>;
@@ -103,17 +100,15 @@ template <
103100
typename perf_tuning_knob_,
104101
typename dtype_scale_,
105102
typename dtype_zero_pt_,
106-
quant_mode quant_type_,
107-
int dequant_s_,
103+
quant_info quant_info_,
108104
mma_engine mma_engine_,
109105
gpu_arch arch_tag_>
110106
struct compute_policy_int4_dequantize<
111107
compute_attr_,
112108
perf_tuning_knob_,
113109
dtype_scale_,
114110
dtype_zero_pt_,
115-
quant_type_,
116-
dequant_s_,
111+
quant_info_,
117112
mma_engine_,
118113
arch_tag_,
119114
std::enable_if_t<mma_engine_ == mma_engine::fpu>> {
@@ -131,20 +126,22 @@ struct compute_policy_int4_dequantize<
131126

132127
static constexpr bool is_int4_matB_policy = true;
133128

134-
static constexpr uint32_t dequant_s = dequant_s_;
129+
static constexpr uint32_t dequant_s = quant_info_.dequant_s;
135130
static_assert(
136131
(dequant_s % (32 / sizeof(dtype_mma_b))) == 0,
137132
"dequant_s should be a multiply of 32B");
138133
using dtype_scale = dtype_scale_;
139134
using dtype_zero_pt = dtype_zero_pt_;
140-
static constexpr quant_mode quant_type = quant_type_;
135+
static constexpr quant_mode quant_mode = quant_info_.quant_mode;
136+
static constexpr bool is_col_major_b =
137+
quant_info_.weight_mem_layout == mem_layout::col_major;
141138

142-
static constexpr uint32_t block_size_y_a = 4;
143-
static constexpr uint32_t block_bytes_x_a = 256;
139+
static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16;
140+
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32;
144141
static constexpr uint32_t block_size_x_a =
145142
block_bytes_x_a / sizeof(dtype_mma_a);
146-
static constexpr uint32_t block_size_x_b = 1;
147-
static constexpr uint32_t block_bytes_y_b = 256;
143+
static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32;
144+
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32;
148145
static constexpr uint32_t block_size_y_b =
149146
block_bytes_y_b / sizeof(dtype_mma_b);
150147

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ template <
3636
typename mem_desc_b_t_,
3737
typename dtype_scale_,
3838
typename dtype_zero_pt_,
39-
uint32_t dequant_s_,
40-
quant_mode quant_type_,
39+
quant_info quant_info_,
4140
mma_engine mma_engine_,
4241
typename pre_processing_t_,
4342
gpu_arch arch_tag_>
@@ -47,8 +46,7 @@ class gemm_t<
4746
perf_tuning_knob_,
4847
dtype_scale_,
4948
dtype_zero_pt_,
50-
quant_type_,
51-
dequant_s_,
49+
quant_info_,
5250
mma_engine_,
5351
arch_tag_>,
5452
tile_shape_, // tile shape of workgroup-level gemm
@@ -66,8 +64,7 @@ class gemm_t<
6664
perf_tuning_knob_,
6765
dtype_scale_,
6866
dtype_zero_pt_,
69-
quant_type_,
70-
dequant_s_,
67+
quant_info_,
7168
mma_engine_,
7269
arch_tag_>;
7370
static constexpr uint32_t k_stride = compute_policy::k_stride;
@@ -80,6 +77,7 @@ class gemm_t<
8077

8178
constexpr static gpu_arch arch_tag = compute_policy::arch_tag;
8279
static constexpr uint32_t dequant_s = compute_policy::dequant_s;
80+
static constexpr quant_mode quant_mode = compute_policy::quant_mode;
8381
using dtype_b = typename mem_desc_b_t::dtype;
8482
using dtype_zero_pt = typename compute_policy::dtype_zero_pt;
8583
static constexpr uint32_t pack_ratio = sizeof(dtype_b) * 2;
@@ -328,7 +326,7 @@ class gemm_t<
328326
scale_t,
329327
zero_pt_t,
330328
dequant_s,
331-
quant_type_>;
329+
quant_mode>;
332330
static constexpr bool enable_periodic_sync = (sync_freq != 0);
333331
static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0;
334332
static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0;
@@ -531,7 +529,7 @@ class gemm_t<
531529
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
532530
scale_prefetch_payload);
533531
if constexpr (
534-
compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) {
532+
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
535533
// TODO 1D prefetch need pack to U32/U64
536534
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
537535
zero_pt_prefetch_payload);
@@ -545,7 +543,7 @@ class gemm_t<
545543
scale_prefetch_payload.template update_tdesc<update_dir_b>(
546544
scale_t::tile_size_y);
547545
if constexpr (
548-
compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) {
546+
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
549547
zero_pt_prefetch_payload
550548
.template update_tdesc<tdesc_update_dir::y_dir>(
551549
zero_pt_t::tile_size_y);
@@ -575,7 +573,7 @@ class gemm_t<
575573
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
576574
scale, scale_payload);
577575
if constexpr (
578-
compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) {
576+
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
579577
subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
580578
zero_pt, zero_pt_payload);
581579
}
@@ -590,7 +588,7 @@ class gemm_t<
590588
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
591589
scale_prefetch_payload);
592590
if constexpr (
593-
compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) {
591+
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
594592
// TODO 1D prefetch need pack to U32/U64
595593
subgroup::tile_prefetch<cache_hint::cached, cache_hint::cached>(
596594
zero_pt_prefetch_payload);
@@ -604,7 +602,7 @@ class gemm_t<
604602
scale_payload.template update_tdesc<update_dir_b>(scale_t::tile_size_y);
605603
}
606604
if constexpr (
607-
compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) {
605+
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
608606
if (tile_k_idx % zero_pt_addr_update_freq == 0) {
609607
zero_pt_payload.template update_tdesc<tdesc_update_dir::y_dir>(
610608
zero_pt_t::tile_size_y);
@@ -619,7 +617,7 @@ class gemm_t<
619617
scale_prefetch_payload.template update_tdesc<tdesc_update_dir::y_dir>(
620618
scale_t::tile_size_y);
621619
if constexpr (
622-
compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) {
620+
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
623621
zero_pt_prefetch_payload
624622
.template update_tdesc<tdesc_update_dir::y_dir>(
625623
zero_pt_t::tile_size_y);
@@ -717,7 +715,7 @@ class gemm_t<
717715
// (offset_y_in_tile) / dequant_s * scale_t::block_size_x +
718716
// offset_x_in_tile;
719717

720-
// if constexpr (compute_policy::quant_type ==
718+
// if constexpr (compute_policy::quant_mode ==
721719
// quant_mode::S4_ASYM) {
722720
// uint32_t zero_pt_idx =
723721
// offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
@@ -739,7 +737,7 @@ class gemm_t<
739737
// cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b +
740738
// ii) - zero_pt_i8;
741739
// } else if constexpr (
742-
// compute_policy::quant_type ==
740+
// compute_policy::quant_mode ==
743741
// quant_mode::S4_FULLRANGE_NO_ZP) {
744742
// cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
745743
// cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b +
@@ -791,7 +789,7 @@ class gemm_t<
791789
xetla_vector<uint8_t, block_size_x_b * block_size_y_b> cvt_blk;
792790
793791
xetla_vector<int32_t, block_size_x_b * block_size_y_b> cvt_blk_i32;
794-
if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) {
792+
if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) {
795793
auto zero_pt_vec = zero_pt.reg
796794
.xetla_select<zero_pt_t::block_size_x, 1>(
797795
scale_block_id * zero_pt_t::block_size_x)
@@ -815,7 +813,7 @@ class gemm_t<
815813
zero_pt_blk.xetla_format<int8_t>());
816814
}
817815
if constexpr (
818-
compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) {
816+
compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
819817
xetla_vector<int8_t, block_size_x_b * block_size_y_b> cvt_blk_i8;
820818
cvt_blk_i8.xetla_select<matB_t::block_elems, 2>(0) =
821819
matB_blk & 0x0f;

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class gemm_universal_t<
159159
/// @brief GEMM arguments.
160160
/// This is the interface for users to pass the application-related runtime
161161
/// variables.
162-
template <quant_mode quant_mode = S4_FULLRANGE_NO_ZP>
162+
template <quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP>
163163
struct arguments_t {
164164
/// @brief Is the size of the m dimension of the matrix multiplication (m x
165165
/// k x n).
@@ -295,7 +295,7 @@ class gemm_universal_t<
295295
}
296296
};
297297
template <>
298-
struct arguments_t<S4_FULLRANGE_NO_ZP> {
298+
struct arguments_t<quant_mode::S4_FULLRANGE_NO_ZP> {
299299
/// @brief Is the size of the m dimension of the matrix multiplication (m x
300300
/// k x n).
301301
uint32_t matrix_m;
@@ -566,7 +566,7 @@ class gemm_universal_t<
566566
implementable &=
567567
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));
568568
if constexpr (
569-
gemm_t::compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) {
569+
gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) {
570570
implementable &= (args.zero_pt_ld % pack_ratio == 0);
571571
}
572572

@@ -668,7 +668,7 @@ class gemm_universal_t<
668668
uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride;
669669
gemm_args_t gemm_args;
670670
if constexpr (
671-
gemm_t::compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) {
671+
gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
672672
gemm_args = gemm_args_t(
673673
mem_desc_a,
674674
mem_desc_b,

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ tile_load(tile_t& tile, payload_t& payload) {
100100
static constexpr bool reg_transpose = tile_desc::reg_transpose;
101101

102102
static constexpr bool mem_transpose = payload_t::mem_transpose;
103-
static constexpr bool trans = reg_transpose ^ mem_transpose;
103+
static constexpr bool trans = payload_t::trans;
104104
static constexpr uint32_t scale_factor = payload_t::scale_factor;
105105

106106
static constexpr bool mem_transform = payload_t::mem_transform;
@@ -535,9 +535,7 @@ tile_load(tile_t& tile, payload_t& payload) {
535535
// }
536536
}
537537

538-
if constexpr (
539-
payload_t::trans &&
540-
!(std::is_same_v<dtype, int4x2> || std::is_same_v<dtype, int4x8>)) {
538+
if constexpr (payload_t::trans) {
541539
SW_BARRIER();
542540
tile_transpose(tile);
543541
}
@@ -604,9 +602,7 @@ tile_load(tile_t& tile, payload_t& payload) {
604602
}
605603
}
606604

607-
if constexpr (
608-
payload_t::trans &&
609-
!(std::is_same_v<dtype, int4x2> || std::is_same_v<dtype, int4x8>)) {
605+
if constexpr (payload_t::trans) {
610606
SW_BARRIER();
611607
tile_transpose(tile);
612608
}

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,14 @@ struct mem_payload_t<
6565
mem_payload_t<mem_desc_t, tile_desc, msg_type::block_2d, arch_tag>;
6666

6767
public:
68-
static constexpr bool mem_transpose = memory_layout == mem_layout::col_major;
68+
static constexpr bool mem_transpose =
69+
memory_layout == mem_layout::col_major &&
70+
!(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
6971

7072
static constexpr reg_layout register_layout = tile_desc::register_layout;
7173
static constexpr bool reg_transpose =
7274
register_layout == reg_layout::transpose_tiled;
75+
7376
static constexpr bool trans = mem_transpose ^ reg_transpose;
7477

7578
static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose &&

include/subgroup/tile/impl/tile_op_functor.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ template <
5757
typename scale_t,
5858
typename zero_pt_t,
5959
uint32_t dequant_s,
60-
quant_mode quant_type>
60+
quant_mode quant_mode>
6161
struct dequant_int4_weight_t {
6262
struct arguments_t {
6363
uint32_t wg_start_m;
@@ -130,7 +130,7 @@ struct dequant_int4_weight_t {
130130
(offset_y_in_tile) / dequant_s * scale_t::block_size_x +
131131
offset_x_in_tile;
132132

133-
if constexpr (quant_type == quant_mode::S4_ASYM) {
133+
if constexpr (quant_mode == quant_mode::S4_ASYM) {
134134
uint32_t zero_pt_idx =
135135
offset_y_in_tile / dequant_s * zero_pt_t::block_size_x +
136136
offset_x_in_tile / pack_ratio;
@@ -149,7 +149,7 @@ struct dequant_int4_weight_t {
149149
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
150150
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
151151
zero_pt_i8;
152-
} else if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) {
152+
} else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) {
153153
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) =
154154
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
155155
int8_t(8);

0 commit comments

Comments
 (0)