Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit b2dfad5

Browse files
committed
save(fix HPC 2D load)
1 parent 0ebd890 commit b2dfad5

File tree

26 files changed

+236
-186
lines changed

26 files changed

+236
-186
lines changed

examples/08_scaled_dot_product_attention/softmax.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,21 @@ struct xetla_softmax_fwd_t {
6060
using softmax_tile_desc_t = subgroup::
6161
tile_desc_t<SIMD, block_height, SIMD, block_height, reg_layout::tiled>;
6262
using softmax_load_t = subgroup::tile_t<dtype_in, softmax_tile_desc_t>;
63+
using mem_desc_in_t = mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>;
6364
using softmax_load_payload_t = subgroup::mem_payload_t<
64-
mem_desc_t<dtype_in, mem_layout::row_major, mem_space_in>,
65+
mem_desc_in_t,
6566
softmax_tile_desc_t,
66-
subgroup::msg_type_v<softmax_tile_desc_t, mem_space_in>,
67+
subgroup::msg_type_v<softmax_tile_desc_t, mem_desc_in_t>,
6768
arch_tag>;
6869

6970
// this tile will store the softmax result to global memory
7071
using softmax_store_t = subgroup::tile_t<dtype_out, softmax_tile_desc_t>;
72+
using mem_desc_out_t =
73+
mem_desc_t<dtype_out, mem_layout::row_major, mem_space_out>;
7174
using softmax_store_payload_t = subgroup::mem_payload_t<
72-
mem_desc_t<dtype_out, mem_layout::row_major, mem_space_out>,
75+
mem_desc_out_t,
7376
softmax_tile_desc_t,
74-
subgroup::msg_type_v<softmax_tile_desc_t, mem_space_out>,
77+
subgroup::msg_type_v<softmax_tile_desc_t, mem_desc_out_t>,
7578
arch_tag>;
7679

7780
struct arguments_t {

examples/09_gate_recurrent_unit/kernel_func.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ struct gru_layer {
156156
using mat_hidden_payload_t = mem_payload_t<
157157
mem_desc_a_t,
158158
matC_tile_desc_t,
159-
msg_type_v<matC_tile_desc_t, mem_loc_input>,
159+
msg_type_v<matC_tile_desc_t, mem_desc_a_t>,
160160
gpu_arch::XeHpc>;
161161
using matC_payload_t = mem_payload_t<
162162
mem_desc_c_t,

include/common/core/memory.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,8 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
356356
__ESIMD_NS::cache_hint_L2<gpu::xetla::detail::get_cache_hint(L2H)>,
357357
__ESIMD_NS::alignment<alignment>};
358358
if constexpr (sizeof(T) * N < sizeof(uint32_t)) {
359-
auto padding_load = __ESIMD_NS::block_load<T, sizeof(uint32_t) / sizeof(T)>(
360-
ptr, byte_offset, props);
361-
return padding_load.xetla_select<N, 1>(0);
359+
xetla_vector<uint32_t, N> offsets(byte_offset, sizeof(T));
360+
return __ESIMD_NS::gather<T, N, uint32_t>(ptr, offsets);
362361
} else {
363362
return __ESIMD_NS::block_load<T, N>(ptr, byte_offset, props);
364363
}
@@ -501,7 +500,13 @@ __XETLA_API void xetla_store_global(
501500
__ESIMD_NS::cache_hint_L1<gpu::xetla::detail::get_cache_hint(L1H)>,
502501
__ESIMD_NS::cache_hint_L2<gpu::xetla::detail::get_cache_hint(L2H)>,
503502
__ESIMD_NS::alignment<alignment>};
504-
__ESIMD_NS::block_store<T, N>(ptr, byte_offset, vals, props);
503+
504+
if constexpr (sizeof(T) * N < sizeof(uint32_t)) {
505+
xetla_vector<uint32_t, N> offsets(byte_offset, sizeof(T));
506+
return __ESIMD_NS::scatter<T, N, uint32_t>(ptr, offsets, vals);
507+
} else {
508+
__ESIMD_NS::block_store<T, N>(ptr, byte_offset, vals, props);
509+
}
505510
}
506511

507512
/// @brief Stateless scattered atomic (0 src).

include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,12 @@ struct row_reduction_fused_op_t<
139139
block_size_y,
140140
reg_layout::tiled>;
141141
using dgelu_w_in_t = subgroup::tile_t<dtype_in, dgelu_tile_desc_t>;
142+
using mem_desc_in_t =
143+
mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>;
142144
using dgelu_w_in_payload_t = subgroup::mem_payload_t<
143-
mem_desc_t<dtype_in, mem_layout::row_major, mem_space::global>,
145+
mem_desc_in_t,
144146
dgelu_tile_desc_t,
145-
subgroup::msg_type_v<dgelu_tile_desc_t, mem_space::global>,
147+
subgroup::msg_type_v<dgelu_tile_desc_t, mem_desc_in_t>,
146148
gpu_arch::XeHpc>;
147149
using dgelu_x_out_t = subgroup::tile_t<dtype_out, dgelu_tile_desc_t>;
148150
using dgelu_x_out_payload_t = subgroup::mem_payload_t<
@@ -234,17 +236,21 @@ struct row_reduction_fused_op_t<
234236
block_size_y,
235237
reg_layout::tiled>;
236238
using mask_in_t = subgroup::tile_t<dtype_mask, reduction_tile_desc_t>;
239+
using mem_desc_mask_t =
240+
mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>;
237241
using mask_in_payload_t = subgroup::mem_payload_t<
238-
mem_desc_t<dtype_mask, mem_layout::row_major, mem_space::global>,
242+
mem_desc_mask_t,
239243
reduction_tile_desc_t,
240-
subgroup::msg_type_v<reduction_tile_desc_t, mem_space::global>,
244+
subgroup::msg_type_v<reduction_tile_desc_t, mem_desc_mask_t>,
241245
gpu_arch::XeHpc>;
242246
using dropout_bwd_out_t =
243247
subgroup::tile_t<dtype_out, reduction_tile_desc_t>;
248+
using mem_desc_out_t =
249+
mem_desc_t<dtype_out, mem_layout::row_major, mem_space::global>;
244250
using dropout_bwd_out_payload_t = subgroup::mem_payload_t<
245-
mem_desc_t<dtype_out, mem_layout::row_major, mem_space::global>,
251+
mem_desc_out_t,
246252
reduction_tile_desc_t,
247-
subgroup::msg_type_v<reduction_tile_desc_t, mem_space::global>,
253+
subgroup::msg_type_v<reduction_tile_desc_t, mem_desc_out_t>,
248254
gpu_arch::XeHpc>;
249255
if (dropout_prob != 0) {
250256
mask_in_t mask_in;

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ class gemm_t<
174174
using matA_payload_t = subgroup::mem_payload_t<
175175
mem_desc_a_t,
176176
matA_tile_desc_t,
177-
subgroup::msg_type_v<matA_tile_desc_t, mem_space_a, mem_desc_a_t::layout>,
177+
subgroup::msg_type_v<matA_tile_desc_t, mem_desc_a_t>,
178178
arch_tag>;
179179
using matA_acc_t = subgroup::tile_t<dtype_mma_a, matA_tile_desc_t>;
180180
using matA_prefetch_payload_t = subgroup::
@@ -204,9 +204,13 @@ class gemm_t<
204204
using matB_payload_t = subgroup::mem_payload_t<
205205
mem_desc_b_t,
206206
matB_tile_desc_t,
207-
subgroup::msg_type_v<matB_tile_desc_t, mem_space_b>,
208-
// subgroup::msg_type_v<matB_tile_desc_t, mem_space_b,
209-
// mem_desc_b_t::layout>,
207+
subgroup::msg_type_v<
208+
matB_tile_desc_t,
209+
mem_desc_t<
210+
typename mem_desc_b_t::dtype,
211+
mem_layout::row_major,
212+
mem_desc_b_t::space>>,
213+
// subgroup::msg_type_v<matB_tile_desc_t, mem_desc_b_t>,
210214
arch_tag>;
211215
using matB_prefetch_payload_t = subgroup::
212216
prefetch_payload_t<mem_desc_b_t, matB_tile_desc_t, wg_size_y, arch_tag>;
@@ -282,10 +286,7 @@ class gemm_t<
282286
using scale_payload_t = subgroup::mem_payload_t<
283287
mem_desc_scale_t,
284288
scale_tile_desc_t,
285-
subgroup::msg_type_v<
286-
scale_tile_desc_t,
287-
mem_space::global,
288-
mem_desc_scale_t::layout>,
289+
subgroup::msg_type_v<scale_tile_desc_t, mem_desc_scale_t>,
289290
arch_tag>;
290291

291292
// compress int4 along N dimensions
@@ -300,10 +301,7 @@ class gemm_t<
300301
using zero_pt_payload_t = subgroup::mem_payload_t<
301302
mem_desc_zero_pt_t,
302303
zero_pt_tile_desc_t,
303-
subgroup::msg_type_v<
304-
zero_pt_tile_desc_t,
305-
mem_space::global,
306-
mem_desc_zero_pt_t::layout>,
304+
subgroup::msg_type_v<zero_pt_tile_desc_t, mem_desc_zero_pt_t>,
307305
arch_tag>;
308306
using scale_prefetch_payload_t = subgroup::
309307
prefetch_payload_t<mem_desc_scale_t, scale_tile_desc_t, 1, arch_tag>;

include/experimental/group/reduction/row_reduce_store_xe.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ struct group_row_reduce_store_t<
5858
using local_st_tile_desc_t =
5959
subgroup::tile_desc_t<row_size, 1, block_size_x, 1, reg_layout::tiled>;
6060
using local_st_t = subgroup::tile_t<dtype_acc, local_st_tile_desc_t>;
61+
using mem_desc_acc =
62+
mem_desc_t<dtype_acc, mem_layout::row_major, mem_space::local>;
6163
using local_st_payload_t = subgroup::mem_payload_t<
62-
mem_desc_t<dtype_acc, mem_layout::row_major, mem_space::local>,
64+
mem_desc_acc,
6365
local_st_tile_desc_t,
64-
subgroup::msg_type_v<local_st_tile_desc_t, mem_space::local>,
66+
subgroup::msg_type_v<local_st_tile_desc_t, mem_desc_acc>,
6567
gpu_arch::XeHpc>;
6668
using local_ld_tile_desc_t = subgroup::tile_desc_t<
6769
local_tile_size_x,
@@ -70,10 +72,12 @@ struct group_row_reduce_store_t<
7072
wg_size_y,
7173
reg_layout::tiled>;
7274
using local_ld_t = subgroup::tile_t<dtype_acc, local_ld_tile_desc_t>;
75+
using mem_desc_ld_t =
76+
mem_desc_t<dtype_acc, mem_layout::row_major, mem_space::local>;
7377
using local_ld_payload_t = subgroup::mem_payload_t<
74-
mem_desc_t<dtype_acc, mem_layout::row_major, mem_space::local>,
78+
mem_desc_ld_t,
7579
local_ld_tile_desc_t,
76-
subgroup::msg_type_v<local_ld_tile_desc_t, mem_space::local>,
80+
subgroup::msg_type_v<local_ld_tile_desc_t, mem_desc_ld_t>,
7781
gpu_arch::XeHpc>;
7882

7983
// If the local tile size is small, we still can use 2D block store

include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct col_major_shuf_t<
8383
using store_tile_payload_t = subgroup::mem_payload_t<
8484
mem_desc_store_tile_t,
8585
store_tile_desc_t,
86-
subgroup::msg_type_v<store_tile_desc_t, mem_space::global>,
86+
subgroup::msg_type_v<store_tile_desc_t, mem_desc_store_tile_t>,
8787
arch_>;
8888

8989
using mem_desc_gidx_t = mem_desc_t<
@@ -97,7 +97,7 @@ struct col_major_shuf_t<
9797
using gidx_payload_t = subgroup::mem_payload_t<
9898
mem_desc_gidx_t,
9999
gidx_tile_desc_t,
100-
subgroup::msg_type_v<gidx_tile_desc_t, mem_space::global>,
100+
subgroup::msg_type_v<gidx_tile_desc_t, mem_desc_gidx_t>,
101101
arch_>;
102102

103103
struct arguments_t {

include/experimental/kernel/data_transformer/data_transformer_xe.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,11 @@ struct xetla_data_transformer<
122122
block_size_y,
123123
in_reg_layout>;
124124
using global_ld_t = subgroup::tile_t<dtype_in, global_ld_tile_desc_t>;
125+
using mem_desc_ld_t = mem_desc_t<dtype_in, mem_layout_in, mem_space::global>;
125126
using global_ld_payload_t = subgroup::mem_payload_t<
126-
mem_desc_t<dtype_in, mem_layout_in, mem_space::global>,
127+
mem_desc_ld_t,
127128
global_ld_tile_desc_t,
128-
subgroup::msg_type_v<global_ld_tile_desc_t, mem_space::global>,
129+
subgroup::msg_type_v<global_ld_tile_desc_t, mem_desc_ld_t>,
129130
gpu_arch::XeHpc>;
130131

131132
using global_st_tile_desc_t = subgroup::tile_desc_t<

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -550,18 +550,18 @@ class gemm_universal_t<
550550
args.matB_base.base, args.matB_ld / pack_ratio);
551551
}
552552
}
553-
if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
554-
if (epilogue_t::msg_type_c == msg_type::block_2d) {
555-
implementable &= kernel::block_2d<arch_tag, dtype_c>::check_tensor(
556-
(uint64_t)(args.matC_base.base),
557-
args.matrix_n,
558-
args.matrix_m,
559-
args.matC_ld);
560-
} else {
561-
implementable &= kernel::general_1d<arch_tag, dtype_c>::check_alignment(
562-
args.matC_base.base, args.matC_ld);
563-
}
564-
}
553+
// if (epilogue_t::msg_type_c != msg_type::unaligned_2d) {
554+
// if (epilogue_t::msg_type_c == msg_type::block_2d) {
555+
// implementable &= kernel::block_2d<arch_tag, dtype_c>::check_tensor(
556+
// (uint64_t)(args.matC_base.base),
557+
// args.matrix_n,
558+
// args.matrix_m,
559+
// args.matC_ld);
560+
// } else {
561+
// implementable &= kernel::general_1d<arch_tag, dtype_c>::check_alignment(
562+
// args.matC_base.base, args.matC_ld);
563+
// }
564+
// }
565565
// check for int4x2
566566
implementable &=
567567
((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0));

include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,26 @@ struct layer_norm_bwd_t<
9292
using gamma_in_t = subgroup::tile_t<dtype_weight, ln_bwd_tile_desc_t>;
9393
using dx_out_t = subgroup::tile_t<dtype_x, ln_bwd_tile_desc_t>;
9494

95+
using mem_desc_y_t =
96+
mem_desc_t<dtype_y, mem_layout::row_major, mem_space::global>;
9597
using dy_in_payload_t = subgroup::mem_payload_t<
96-
mem_desc_t<dtype_y, mem_layout::row_major, mem_space::global>,
98+
mem_desc_y_t,
9799
ln_bwd_tile_desc_t,
98-
subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
100+
subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_desc_y_t>,
99101
gpu_arch::XeHpc>;
102+
using mem_desc_x_t =
103+
mem_desc_t<dtype_x, mem_layout::row_major, mem_space::global>;
100104
using x_in_payload_t = subgroup::mem_payload_t<
101-
mem_desc_t<dtype_x, mem_layout::row_major, mem_space::global>,
105+
mem_desc_x_t,
102106
ln_bwd_tile_desc_t,
103-
subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
107+
subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_desc_x_t>,
104108
gpu_arch::XeHpc>;
109+
using mem_desc_weight_t =
110+
mem_desc_t<dtype_weight, mem_layout::row_major, mem_space::global>;
105111
using gamma_in_payload_t = subgroup::mem_payload_t<
106-
mem_desc_t<dtype_weight, mem_layout::row_major, mem_space::global>,
112+
mem_desc_weight_t,
107113
ln_bwd_tile_desc_t,
108-
subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_space::global>,
114+
subgroup::msg_type_v<ln_bwd_tile_desc_t, mem_desc_weight_t>,
109115
gpu_arch::XeHpc>;
110116
using dx_out_payload_t = subgroup::mem_payload_t<
111117
mem_desc_t<dtype_x, mem_layout::row_major, mem_space::global>,

0 commit comments

Comments
 (0)