From 6a94b64a3646f80f1d2b97dc86ea83e67f499e41 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 14 Jan 2021 14:39:04 +0800 Subject: [PATCH 01/40] add 1st version of nhwc --- config/igemm_fwd_gtc_gfx908_nhwc.config | 46 + igemm/algo/__init__.py | 1 + igemm/algo/igemm_base.py | 17 +- igemm/algo/igemm_fwd_gtc_nhwc.py | 1783 +++++++++++++++++++++++ igemm/codegen/compile.py | 2 +- igemm/igemm_codegen_driver.py | 5 +- 6 files changed, 1847 insertions(+), 7 deletions(-) create mode 100644 config/igemm_fwd_gtc_gfx908_nhwc.config create mode 100755 igemm/algo/igemm_fwd_gtc_nhwc.py diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config new file mode 100644 index 00000000..2c664ca3 --- /dev/null +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -0,0 +1,46 @@ +[codegen] +arch = 'gfx908' +code_object = 'cov3' +mode = 'flat' + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxN0xN1B +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxN0xN1B +tensor_b_thread_lengths = [1, 4, 2] # ExCxK +tensor_b_cluster_lengths = [1, 4, 64] # ExCxK +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 4 +nxe = 0 + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxN0xN1B +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxN0xN1B +tensor_b_thread_lengths = [1, 4, 2] # ExCxK +tensor_b_cluster_lengths = [1, 4, 64] # ExCxK +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 4 +nxe = 1 diff --git a/igemm/algo/__init__.py b/igemm/algo/__init__.py index 80788a52..985cdf0f 100755 --- a/igemm/algo/__init__.py +++ b/igemm/algo/__init__.py @@ -32,6 +32,7 @@ from .igemm_bwd_gtc import * from .igemm_wrw_gtc import * from .igemm_fwd_gtc import * +from .igemm_fwd_gtc_nhwc import * from .igemm_upsampling_clear import * from .utility import * from .thread_mapping import * diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index 749fb29f..e21ed686 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -36,7 +36,7 @@ IGEMM_GTC_FEAT_PRECACHE_SOFFSET = 1 IGEMM_GTC_FEAT_LOCAL_PREFETCH = 1 IGEMM_GTC_FEAT_FMA_INTERLEAVE = 1 -IGEMM_GTC_FEAT_MAGIC_DIVISION = 1 +IGEMM_GTC_FEAT_MAGIC_DIVISION = 0 IGEMM_GTC_FEAT_SOURCE_ACCESS_ENCODING_KERNEL_NAME = 0 # IGEMM_GTC_TENSOR_LAYOUT_NCHW = ((1 << 4) | 0) @@ -185,7 +185,7 @@ def __init__(self, tunable_dict): self.allow_lds_reorder = utility_dict_with_default_t(tunable_dict)('allow_lds_reorder', IGEMM_GTC_FEAT_ALLOW_LDS_REORDER) self.precache_soffset = utility_dict_with_default_t(tunable_dict)('precache_soffset', IGEMM_GTC_FEAT_PRECACHE_SOFFSET) - default_source_access_order = IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M if self.direction == 'fwd' \ + default_source_access_order = IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M if (self.direction == 'fwd' and self.tensor_layout == 'nchw') \ else IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N self.source_access_order = utility_dict_with_default_t(tunable_dict)('source_access_order', default_source_access_order) @@ -227,9 +227,16 @@ def _unmerge_x1_from_e(unroll_k, nxe): if self.direction == 
'fwd': assert self.gemm_n_per_block % self.nxb == 0 - self.unmerge_sub_n = self.gemm_n_per_block // self.nxb - self.unmerge_sub_k = 1 # not used - self.unmerge_sub_c = _unmerge_x1_from_e(self.gemm_k_per_block, self.nxe) + if self.tensor_layout == 'nchw': + self.unmerge_sub_n = self.gemm_n_per_block // self.nxb + self.unmerge_sub_k = 1 # not used + self.unmerge_sub_c = _unmerge_x1_from_e(self.gemm_k_per_block, self.nxe) + elif self.tensor_layout == 'nhwc': + self.unmerge_sub_n = self.gemm_m_per_block // self.nxb + self.unmerge_sub_k = 1 # not used + self.unmerge_sub_c = 1 # not used + else: + assert False elif self.direction == 'bwd': assert self.gemm_n_per_block % self.nxb == 0 self.unmerge_sub_n = self.gemm_n_per_block // self.nxb diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py new file mode 100755 index 00000000..a16e1cbe --- /dev/null +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -0,0 +1,1783 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2020-2021 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +# pylint: disable=maybe-no-member +from ..codegen import * +from .fma_main_loop import * +from .igemm_base import * +from .global_memory import * +from .shared_memory import * +from .utility import * +from .thread_mapping import * +from .xdlops_mapping import * +from .coalescing_store import * +from .mfma_main_loop import * + + +IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B = 0 +IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0 = 1 +# IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_N_E_K = 4 +# IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_N_K_E = 5 + + +def _find_non_1_index_in_list(list_object): + result_list = list() + for idx, item in enumerate(list_object): + assert type(item) is int + if item != 1: + result_list.append(idx) + return result_list + +class igemm_fwd_gtc_nhwc_t(mc_base_t): + ''' + tensor a (input) tensor b (wei) + thread_lengths : ta_e, ta_c, ta_n0, ta_n1b, tb_e, tb_c, tb_k + cluster_lengths : ca_e, ca_c, ca_n0, ca_n1b, cb_e, cb_c, cb_k + + for a/b tensor, always load gemm_k dimension first. 
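+    for example, the 256x128 gfx908 config above gives
+        tensor_a: thread_lengths [1, 4, 4, 1], cluster_lengths [1, 4, 1, 64]   (ExCxN0xN1B)
+        tensor_b: thread_lengths [1, 4, 2],    cluster_lengths [1, 4, 64]      (ExCxK)
+    so per block (figures are only an illustration derived from that config):
+        gemm_m = (ta_n0*ca_n0) * (ta_n1b*ca_n1b) = (4*1) * (1*64) = 256
+        gemm_n = tb_k*cb_k                       = 2*64           = 128
+        gemm_k = (ta_e*ca_e) * (ta_c*ca_c)       = (1*1) * (4*4)  = 16
+        block size = ca_e*ca_c*ca_n0*ca_n1b = cb_e*cb_c*cb_k = 256 threads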
+ + ''' + def __init__(self, mc, tunable): + assert type(tunable) is igemm_gtc_tunable_parameter_t + mc_base_t.__init__(self, mc) + self.tunable = tunable + self.global_load_in = self.global_load_in_t(mc, self) + self.global_load_wei = self.global_load_wei_t(mc, self) + self.shared_store_in = self.shared_store_in_t(mc, self) + self.shared_store_wei = self.shared_store_wei_t(mc, self) + + in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + self.in_thread_copy_ndim = len(in_thread_copy_index) + self.wei_thread_copy_ndim = len(wei_thread_copy_index) + assert self.in_thread_copy_ndim in (0, 1, 2) + assert self.wei_thread_copy_ndim in (0, 1, 2) + + + self.coalescing_store_groups = igemm_next_pow2(self.tunable.coalescing_store_groups) + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + assert (self.tunable.gemm_m_per_thread * self.tunable.gemm_m_repeat) % self.coalescing_store_groups == 0, \ + f"coalescing store groups should be divided by thread m {self.tunable.gemm_m_per_thread}x{self.tunable.gemm_m_repeat}" + + ctrl_thread_mapping = ctrl_thread_mapping_t() + # -> MR x NR x ML1 x NL1 x ML0 x NL0 + ctrl_thread_mapping.thread_lengths = [self.tunable.gemm_m_repeat, self.tunable.gemm_n_repeat, 1, 1, self.tunable.gemm_m_per_thread, self.tunable.gemm_n_per_thread] + ctrl_thread_mapping.cluster_lengths = [1, 1, self.tunable.gemm_m_level1_cluster, self.tunable.gemm_n_level1_cluster, self.tunable.gemm_m_level0_cluster, self.tunable.gemm_n_level0_cluster] + self.thread_mapping = igemm_thread_mapping_t(self.mc, ctrl_thread_mapping) + + ctrl_coalescing_store = ctrl_coalescing_store_t() + ctrl_coalescing_store.ctm = ctrl_thread_mapping + ctrl_coalescing_store.coalescing_groups = self.coalescing_store_groups + ctrl_coalescing_store.data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + ctrl_coalescing_store.vector_write_out = 1 # TODO: some cases this can be set to other value + ctrl_coalescing_store.block_size = self.tunable.block_size + + gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + na_c0, na_c1e, na_k0, na_k1, nb_c0, nb_c1e, nb_n0, nb_n1b = self.get_dims_lengths() + ctrl_coalescing_store.gemm_m_m0_m1 = [na_k0, na_k1] + if gemm_m_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_M_K1_K0: + ctrl_coalescing_store.gemm_m_order = IGEMM_COALESCING_GEMM_M_ORDER_M1_M0 + + ctrl_coalescing_store.adjust_optimal_coalescing_groups() # in m1_m0 order, must adjust + self.coalescing_store = igemm_coalescing_store_t(mc, ctrl_coalescing_store) + + else: + def flatten(x): + from functools import reduce + return reduce(lambda a, b: a*b, x, 1) + ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile_fp32(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, + self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE) + self.xdlops_mapping = igemm_xdlops_mapping_t(self.mc, ctrl_xdlops_mapping) + assert flatten(ctrl_xdlops_mapping.acc_c_per_thread_m()) % self.coalescing_store_groups == 0, \ + f"coalescing store groups should be divided by agpr per thread in m direction {ctrl_xdlops_mapping.acc_c_per_thread_m()}" + + ctrl_coalescing_store_xdlops = ctrl_coalescing_store_xdlops_t() + ctrl_coalescing_store_xdlops.cxm = ctrl_xdlops_mapping + ctrl_coalescing_store_xdlops.coalescing_groups = self.coalescing_store_groups + ctrl_coalescing_store_xdlops.data_byte = 
amdgpu_precision_data_byte(self.tunable.precision) + + ctrl_coalescing_store_xdlops.vector_write_out = 1 # TODO: some cases this can be set to other value + ctrl_coalescing_store_xdlops.block_size = self.tunable.block_size + + gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + ctrl_coalescing_store_xdlops.gemm_m_m0_m1 = [na_n0, na_n1b] + if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0: + # we may consider not suppor this mode + ctrl_coalescing_store_xdlops.gemm_m_order = IGEMM_COALESCING_GEMM_M_ORDER_M1_M0 + ctrl_coalescing_store_xdlops.adjust_optimal_coalescing_groups() # in m1_m0 order, must adjust + self.coalescing_store = igemm_coalescing_store_xdlops_t(mc, ctrl_coalescing_store_xdlops) + + + self.label_out = f"L_{self.name()}_out" + self.dict_shifted_stride = dict() + + self.karg = self.kernel_karg_t(mc, self) + self.sgpr = self.kernel_sgpr_t(mc, self) + self.vgpr = self.kernel_vgpr_t(mc, self) + if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + self.agpr = self.kernel_agpr_t(mc, self) + + def name(self): + return igemm_gtc_encode_kernel_name(self.tunable) + + def try_shift_stride(self, gpr, shifter): + assert type(gpr) is sym_t + with self._deferred_context(): + if gpr.label not in self.dict_shifted_stride: + self.dict_shifted_stride[gpr.label] = gpr + self._emit(f"s_lshl_b32 s[{gpr()}], s[{gpr()}], {shifter}") + return self._get_deferred() + + def get_lds_gemm_m_gemm_n_order(self): + def need_reverse_order(x0, x1): + if x0 != 1 and x1 == 1: + return True + if x0 > x1: + return True + return False + + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + + gemm_n_order = -1 # gemm_n order is not supported + + gemm_m_order = IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B + if self.tunable.allow_lds_reorder: + if need_reverse_order(ta_n0, ta_n1b): + gemm_m_order = IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0 + assert False, "maybe not correct" + + return gemm_m_order, gemm_n_order + + class macro_set_flag_hw(macro_base_t): + def __init__(self, mc, inline = False): + macro_base_t.__init__(self, mc, inline) + self.declare_arg("v_flag") + self.declare_arg("v_ih") + self.declare_arg("v_iw") + self.declare_arg("s_h") + self.declare_arg("s_w") + def name(self): + return '.v_fwd_gtc_nhwc_set_flag_hw' + + def expr(self): + self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_h()}], v[{self.v_ih()}]") + self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, 1, vcc") + self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_w()}], v[{self.v_iw()}]") + self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag()}], vcc") + + class macro_in_update_hw_t(macro_base_t): + def __init__(self, mc, inline = False): + macro_base_t.__init__(self, mc, inline) + self.declare_arg("v_in_ihi") + self.declare_arg("v_in_iwi") + self.declare_arg("v_in_iho") + self.declare_arg("v_in_iwo") + self.declare_arg("v_in_iy") + self.declare_arg("v_in_ix") + self.declare_arg("s_dilation_h") + self.declare_arg("s_dilation_w") + def name(self): + return '.v_fwd_gtc_nhwc_in_update_hw' + + def expr(self): + self._emit(f"; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h, here make sure iho <- iho * s_stride_h - s_pad_h before hand") + self._emit(f"; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w, here make sure iwo <- iwo * s_stride_w - s_pad_w before hand") + self._emit(f"v_mad_i32_i24 v[{self.v_in_ihi()}], s[{self.s_dilation_h()}], v[{self.v_in_iy()}], v[{self.v_in_iho()}]") + self._emit(f"v_mad_i32_i24 
v[{self.v_in_iwi()}], s[{self.s_dilation_w()}], v[{self.v_in_ix()}], v[{self.v_in_iwo()}]") + + class macro_in_update_os_t(macro_base_t): + def __init__(self, mc, data_byte, inline = False): + macro_base_t.__init__(self, mc, inline) + self.data_byte = data_byte + self.declare_arg("v_in_os") + self.declare_arg("v_in_os_base") + self.declare_arg("v_in_ihi") + self.declare_arg("v_in_iwi") + self.declare_arg("s_wi") + self.declare_arg("s_in_stride_wi") + self.declare_arg("v_tmp") + def name(self): + return '.v_fwd_gtc_nhwc_in_update_os' + + def expr(self): + self._emit(f"v_mad_u32_u24 v[{self.v_tmp()}], v[{self.v_in_ihi()}], s[{self.s_wi()}], v[{self.v_in_iwi()}]") + self._emit(f"v_mul_lo_u32 v[{self.v_tmp()}], s[{self.s_in_stride_wi()}], v[{self.v_tmp()}]") + self._emit(f"v_add_u32 v[{self.v_in_os()}], v[{self.v_tmp()}], v[{self.v_in_os_base()}]") + + class macro_move_slice_window_k_e1_c_t(macro_base_t): + ''' + nhwc gemm_k = e*c, and thread/cluster length for e is always 1 + hence always move along c and accumulate into e + + this macro is for input and weight together. + ''' + def __init__(self, mc, tunable, inline = False): + macro_base_t.__init__(self, mc, inline) + self.tunable = tunable + #self.declare_arg("v_move_slice_k_iy") + self.declare_arg("v_move_slice_k_ix") + self.declare_arg("v_move_slice_k_ic") + #self.declare_arg("s_gemm_k_num_y") + self.declare_arg("s_gemm_k_num_x") + self.declare_arg("s_gemm_k_num_c") + self.declare_arg("s_move_slice_k_c") + self.declare_arg("v_in_os") + self.declare_arg("v_wei_os") + + self.declare_arg("s_in_stride_c") # this is indeed s_move_slice_k_c * data_byte + self.declare_arg("s_in_stride_gemm_k_num_c") + self.declare_arg("s_in_stride_diff_x") # indeed stride_x - stride_c, always possitive + self.declare_arg("s_in_stride_diff_y") # indeed stride_y - stride_x, always possitive + + self.declare_arg("v_in_ihi") # need update + self.declare_arg("v_in_iwi") # need update + # self.declare_arg("s_dilation_h") + # self.declare_arg("s_dilation_w") + self.declare_arg("s_in_diff_hi") # s_dilation_h + self.declare_arg("s_in_diff_wi") # s_dilation_w + self.declare_arg("s_in_diff_sub_wi") # total wi needed to be deduced from iwi, when carry-on + + def name(self): + return '.v_fwd_gtc_nhwc_move_slice_window_k_e1_c' + + def expr(self): + self._emit(f"v_add_u32 v[{self.v_move_slice_k_ic()}], s[{self.s_move_slice_k_c()}], v[{self.v_move_slice_k_ic()}]") + self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_in_stride_c()}], v[{self.v_in_os()}]") + self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_wei_stride_c()}], v[{self.v_wei_os()}]") # weight offset always increase, treat y*x*c as single dimension + self._emit(f"v_cmpx_le_u32 vcc, s[{self.s_gemm_k_num_c()}], v[{self.v_move_slice_k_ic()}]") + self._emit(f"v_subrev_u32 v[{self.v_move_slice_k_ic()}], s[{self.s_gemm_k_num_c()}], v[{self.v_move_slice_k_ic()}]") + self._emit(f"v_add_u32 v[{self.v_move_slice_k_ix()}], 1, v[{self.v_move_slice_k_ix()}]") + self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_in_stride_diff_x()}], v[{self.v_in_os()}]") # merge with above c + self._emit(f"v_add_u32 v[{self.v_in_iwi()}], s[{self.s_in_diff_wi()}], v[{self.v_in_iwi()}]") + self._emit(f"s_mov_b64 exec, -1") + self._emit_empty_line() + self._emit(f"v_cmpx_le_u32 vcc s[{self.s_gemm_k_num_x()}], v[{self.v_move_slice_k_ix()}]") + self._emit(f"v_add_u32 v[{self.v_in_os_base()}], s[{self.s_in_stride_diff_y()}], v[{self.v_in_os_base()}]") + self._emit(f"v_subrev_u32 v[{self.v_in_iwi()}], s[{self.s_in_diff_sub_wi()}], 
v[{self.v_in_iwi()}]") + self._emit(f"v_add_u32 v[{self.v_in_ihi()}], s[{self.s_in_diff_hi()}], v[{self.v_in_ihi()}]") + self._emit(f"s_mov_b64 exec, -1") + self._emit_empty_line() + # free of last dim check + + class global_load_in_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + self.outer = outer + def get_issues(self): + m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() + return m_in_2d_global_load.get_issues() + + def __call__(self): + s = self.outer.sgpr + v = self.outer.vgpr + + m_wei_2d_global_load, m_in_2d_global_load = self.outer.get_macro_global_load() + s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.outer.get_symbol_global_load_s_stride_d0_d1() + with self._deferred_context(): + self._emit(f"; load input") + if self.outer.tunable.nxe != 0: + self._emit(f".v_clear_nc {v.v_gld_b()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") + self._emit(f"v_cmp_eq_u32 vcc, 1, v[{v.v_in_flag()}]") + self._emit(f"s_and_saveexec_b64 s[{s.s_tmp(4)}:{s.s_tmp(5)}], vcc") + if self.outer.tunable.precache_soffset: + self._emit(m_in_2d_global_load(v.v_gld_b(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_in_offset())) + else: + self._emit(m_in_2d_global_load(v.v_gld_b(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_tmp())) + if self.outer.tunable.nxe != 0: + self._emit(f"s_or_b64 exec, exec, s[{s.s_tmp(4)}:{s.s_tmp(5)}]") + return self._get_deferred() + + # def is_1d_move_slice_k(self): + # ''' + # this now only meaning for input tensor + # ''' + # na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + # if self.tunable.nxe != 0: + # return False # if not nxe 0, it is possible that we can do move slice, but that will lead to extra index calculation + # if nb_c1e != 1 and nb_c0 == 1: + # return True + # # it is meanless to let n_c1e==1 and n_c0!=1 + # return False + + class global_load_wei_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + self.outer = outer + def get_issues(self): + m_wei_2d_global_load, m_in_2d_global_load = self.outer.get_macro_global_load() + return m_wei_2d_global_load.get_issues() + + def __call__(self): + s = self.outer.sgpr + v = self.outer.vgpr + + m_wei_2d_global_load, m_in_2d_global_load = self.outer.get_macro_global_load() + s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.outer.get_symbol_global_load_s_stride_d0_d1() + with self._deferred_context(): + self._emit(f"; load weight") + # self._emit(f".v_clear_nc {v.v_gld_a()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") + if self.outer.tunable.precache_soffset: + self._emit(m_wei_2d_global_load(v.v_gld_a(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset())) + else: + self._emit(m_wei_2d_global_load(v.v_gld_a(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_tmp())) + return self._get_deferred() + + class shared_store_in_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + self.outer = outer + def get_issues(self): + m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() + return m_in_2d_shared_store.get_issues() + + def __call__(self): + s = self.outer.sgpr + v = self.outer.vgpr + m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() + with self._deferred_context(): + self._emit(m_in_2d_shared_store(v.v_gld_b(), v.v_sst_b_os())) + return 
self._get_deferred() + + class shared_store_wei_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + self.outer = outer + def get_issues(self): + m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() + return m_wei_2d_shared_store.get_issues() + + def __call__(self): + s = self.outer.sgpr + v = self.outer.vgpr + m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() + with self._deferred_context(): + self._emit(m_wei_2d_shared_store(v.v_gld_a(), v.v_sst_a_os())) + return self._get_deferred() + + class kernel_karg_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + self.outer = outer + self.k_p_in = sym_t('k_p_in' ,0) + self.k_p_wei = sym_t('k_p_wei' ,8) + self.k_p_out = sym_t('k_p_out' ,16) + self.k_hi = sym_t('k_hi' ,24) + self.k_wi = sym_t('k_wi' ,28) + self.k_n = sym_t('k_n' ,32) + self.k_k = sym_t('k_k' ,36) + self.k_c = sym_t('k_c' ,40) + self.k_ho = sym_t('k_ho' ,44) + self.k_wo = sym_t('k_wo' ,48) + self.k_stride_h = sym_t('k_stride_h' ,52) + self.k_stride_w = sym_t('k_stride_w' ,56) + self.k_dilation_h = sym_t('k_dilation_h' ,60) + self.k_dilation_w = sym_t('k_dilation_w' ,64) + self.k_pad_h = sym_t('k_pad_h' ,68) + self.k_pad_w = sym_t('k_pad_w' ,72) + self.k_y = sym_t('k_y' ,76) + self.k_x = sym_t('k_x' ,80) + self.k_group = sym_t('k_group' ,84) + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self.k_magic_0 = sym_t('k_magic_0' ,88) + self.k_magic_1 = sym_t('k_magic_1' ,92) + self.k_magic_2 = sym_t('k_magic_2' ,96) + self.k_magic_3 = sym_t('k_magic_3' ,100) + self.k_magic_4 = sym_t('k_magic_4' ,104) + self.k_magic_5 = sym_t('k_magic_5' ,108) + self.k_magic_6 = sym_t('k_magic_6' ,112) + self.k_shift_pack_0 = sym_t('k_shift_pack_0' ,116) + self.k_shift_pack_1 = sym_t('k_shift_pack_1' ,120) + self.k__pack_0 = sym_t('k__pack_0' ,124) + self.k_end = sym_t('k_end' ,128) + else: + self.k_end = sym_t('k_end' ,88) + + def get_count(self): + return self.k_end.value + + def emit(self): + for k, v in self.__dict__.items(): + if k.startswith('k_'): + self._emit(v.declare()) + + class kernel_sgpr_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + ta_n0, ta_n1b, ta_e, ta_c, tb_k = outer.get_thread_lengths() + sseq = gpr_sequencer_t() + self.outer = outer + self.s_ka = sym_t('s_ka' , sseq(2)) + self.s_bx = sym_t('s_bx' , sseq(2)) + self.s_p_in = sym_t('s_p_in' , sseq(4)) + self.s_p_wei = sym_t('s_p_wei' , sseq(4)) + self.s_p_out = sym_t('s_p_out' , sseq(4)) + self.s_hi = sym_t('s_hi' , sseq(1)) + self.s_wi = sym_t('s_wi' , sseq(1)) + self.s_n = sym_t('s_n' , sseq(1)) + self.s_k = sym_t('s_k' , sseq(1)) # this is indeed k_per_group + self.s_c = sym_t('s_c' , sseq(1)) # this is indeed c_per_group + if outer.tunable.nxe != 0: + self.s_ho = sym_t('s_ho' , sseq(1)) + self.s_wo = sym_t('s_wo' , sseq(1)) + self.s_stride_h = sym_t('s_stride_h' , sseq(1)) + self.s_stride_w = sym_t('s_stride_w' , sseq(1)) + self.s_dilation_h = sym_t('s_dilation_h' , sseq(1)) + self.s_dilation_w = sym_t('s_dilation_w' , sseq(1)) + self.s_pad_h = sym_t('s_pad_h' , sseq(1)) + self.s_pad_w = sym_t('s_pad_w' , sseq(1)) + self.s_y = sym_t('s_y' , sseq(1)) + self.s_x = sym_t('s_x' , sseq(1)) + self.s_group = sym_t('s_group' , sseq(1)) + + # stride for in + self.s_in_stride_hi = sym_t('s_in_stride_hi' , sseq(1)) + self.s_in_stride_wi = sym_t('s_in_stride_wi' , sseq(1)) + self.s_in_stride_n = sym_t('s_in_stride_n' , sseq(1)) + if ta_n0 != 1: + self.s_in_stride_n0 = sym_t('s_in_stride_n0' , sseq(1)) + + # 
stride for wei + self.s_wei_stride_k = sym_t('s_wei_stride_k' , sseq(1)) + if outer.tunable.nxe != 0: + self.s_wei_stride_y = sym_t('s_wei_stride_y' , sseq(1)) + self.s_stride_c = sym_t('s_stride_c' , sseq(1)) + self.s_in_stride_c = sym_t("s_in_stride_c" , self.s_stride_c.value) + self.s_wei_stride_c = sym_t("s_wei_stride_c" , self.s_stride_c.value) + + # stride for out + self.s_out_stride_ho = sym_t('s_out_stride_ho' , sseq(1)) + self.s_out_stride_n = sym_t('s_out_stride_n' , sseq(1)) + if ta_n0 != 1: + self.s_out_stride_n0 = sym_t('s_out_stride_n0' , sseq(1)) + + + + self.s_in_stride_c_c1 = sym_t("s_in_stride_c_c1" , sseq(1)) + self.s_in_stride_c_c0_c1_diff = sym_t("s_in_stride_c_c0_c1_diff" , sseq(1)) + + self.s_block_gtc_ig = sym_t("s_block_gtc_ig" , sseq(1)) + self.s_block_gtc_ik = sym_t("s_block_gtc_ik" , sseq(1)) + self.s_block_gtc_in0 = sym_t("s_block_gtc_in0" , sseq(1)) + self.s_block_gtc_in1b = sym_t("s_block_gtc_in1b" , sseq(1)) + + self.s_move_slice_k_c1e = sym_t("s_move_slice_k_c1e" , sseq(1)) + if outer.tunable.nxe != 0: + self.s_move_slice_k_c1 = sym_t("s_move_slice_k_c1" , sseq(1)) + self.s_move_slice_k_y = sym_t("s_move_slice_k_y" , sseq(1)) + self.s_move_slice_k_x = sym_t("s_move_slice_k_x" , self.s_block_gtc_ig.value) + + self.s_knum = sym_t("s_knum" , 3) + self.s_gemm_k_num_c1 = sym_t("s_gemm_k_num_c1" , sseq(1)) + if outer.tunable.nxe != 0: + self.s_gemm_k_num_y = sym_t("s_gemm_k_num_y" , self.s_y.value) + self.s_gemm_k_num_x = sym_t("s_gemm_k_num_x" , self.s_x.value) + + #if outer.tunable.nxe != 0: + self.s_dim_b = sym_t("s_dim_b" , sseq(1)) + + self.s_kitr = sym_t("s_kitr" , 1) + if outer.tunable.precache_soffset: + m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() + in_npc = m_in_2d_global_load.get_num_precache_soffset() + wei_npc = m_wei_2d_global_load.get_num_precache_soffset() + self.s_in_offset = sym_t("s_in_offset" ,sseq(in_npc)) # if this number is zero, it is also OK, since we would not use + self.s_wei_offset = sym_t("s_wei_offset" ,sseq(wei_npc)) + self.s_k_padded = sym_t("s_k_padded" ,sseq(1)) + + # TODO: this sgpr allocation is a mess + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + # allocate several sgpr to hold magic/shift value. 
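+                # magic division: the host precomputes a (magic, shift) pair per divisor d and the
+                # kernel computes n/d without an integer divide. one typical scheme, which the
+                # m_mdiv_* macros are expected to implement in some form (the exact helper used by
+                # this generator may differ), is
+                #     q = (umulhi(n, magic) + n) >> shift
+                # with   shift = smallest s such that 2**s >= d
+                #        magic = (2**32 * (2**s - d)) // d + 1      # e.g. d=3 -> s=2, magic=0x55555556
+                # the per-divisor shift values are packed as 8-bit fields, four per dword, into
+                # shift_pack_0/1 and extracted with s_bfe_u32 in the kernel prologue.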
+ self.s_shift_pack_0 = sym_t("s_shift_pack_0" ,self.s_p_out.value + 2) + self.s_shift_pack_1 = sym_t("s_shift_pack_1" ,self.s_p_out.value + 3) + + self.s_magic_2 = sym_t("s_magic_2" ,self.s_in_stride_c_c1.value) # when load, loadx4 with magic_0/1 + self.s_magic_3 = sym_t("s_magic_3" ,self.s_in_stride_c_c0_c1_diff.value) # when load, loadx4 with magic_0/1 + + self.s_magic_4 = sym_t("s_magic_4" ,self.s_move_slice_k_c1e.value) + self.s_magic_5 = sym_t("s_magic_5" ,self.s_gemm_k_num_c1.value) + self.s_magic_6 = sym_t("s_magic_6" ,self.s_block_gtc_in0.value) + + self.s_tmp = sym_t("s_tmp" ,sseq(6, 2)) + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self.s_magic_0 = sym_t("s_magic_0" ,self.s_p_wei.value + 2) + self.s_magic_1 = sym_t("s_magic_1" ,self.s_p_wei.value + 3) + + self.s_end = sym_t("s_end" ,sseq()) + + def get_count(self): + return self.s_end.value + + def emit(self): + assert self.s_end.value <= amdgpu_sgpr_limit(self.mc.arch_config.arch), f"s_end:{self.s_end.value}, tunable:{self.outer.tunable.serialize()}" + for k, v in self.__dict__.items(): + if k.startswith('s_'): + self._emit(v.declare()) + + class kernel_vgpr_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + self.outer = outer + ta_n0, ta_n1b, ta_e, ta_c, tb_k = outer.get_thread_lengths() + ca_n0, ca_n1b, ca_e, ca_c, cb_k = outer.get_cluster_lengths() + + is_vgpr_acc_c = outer.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS + vseq = gpr_sequencer_t() + if is_vgpr_acc_c: + self.v_c = sym_t("v_c" ,vseq(outer.tunable.num_vgpr_accumulate_c)) + v_c_num = vseq() + else: + v_c_resuable_num = outer.tunable.num_vgpr_accumulate_a + outer.tunable.num_vgpr_accumulate_b + \ + outer.tunable.num_vgpr_global_load_a + outer.tunable.num_vgpr_global_load_b + \ + 16 # from v_sst_a_os to v_co_sst + v_c_coalescing_num = outer.tunable.num_agpr_accumulate_c // outer.coalescing_store_groups + v_c_needed = (v_c_coalescing_num - v_c_resuable_num) if (v_c_coalescing_num - v_c_resuable_num) > 0 else 0 + + v_c_needed = v_c_needed if v_c_needed > 2 else 2 # let at least 2 + self.v_c = sym_t("v_c" ,vseq(v_c_needed), f"coalescing:{v_c_coalescing_num}, needed:{v_c_needed}, resuable:{v_c_resuable_num}") + + self.v_a = sym_t("v_a" ,vseq(outer.tunable.num_vgpr_accumulate_a)) + self.v_b = sym_t("v_b" ,vseq(outer.tunable.num_vgpr_accumulate_b)) + self.v_gld_a = sym_t("v_gld_a" ,vseq(outer.tunable.num_vgpr_global_load_a)) + self.v_gld_b = sym_t("v_gld_b" ,vseq(outer.tunable.num_vgpr_global_load_b)) + self.v_sst_a_os = sym_t("v_sst_a_os" ,vseq(1)) + self.v_sst_b_os = sym_t("v_sst_b_os" ,vseq(1)) + self.v_sld_a_os = sym_t("v_sld_a_os" ,vseq(1)) + self.v_sld_b_os = sym_t("v_sld_b_os" ,vseq(1)) + self.v_in_os = sym_t("v_in_os" ,vseq(1)) + self.v_in_os_base = sym_t("v_in_os_base" ,vseq(1)) + if outer.tunable.nxe != 0: + self.v_in_flag = sym_t("v_in_flag" ,vseq(1)) + self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) + + self.v_gtc_ta_ic = sym_t("v_gtc_ta_ic" ,vseq(1)) + if ca_n0 != 1: + self.v_gtc_ta_in0 = sym_t("v_gtc_ta_in0" ,vseq(1)) + self.v_gtc_ta_in1b = sym_t("v_gtc_ta_in1b" ,vseq(1)) + self.v_gtc_ta_in1 = sym_t("v_gtc_ta_in1" ,vseq(1)) + + self.v_gtc_tb_ik = sym_t("v_gtc_tb_ik" ,vseq(1)) + + self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) + self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) + + self.v_out_os = sym_t("v_out_os" ,vseq(1)) + if outer.tunable.nxe != 0: + self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) + self.v_out_in0 = sym_t("v_out_in0" ,vseq(1)) + self.v_out_in1b = sym_t("v_out_in1b" ,vseq(1)) + self.v_out_in1 = sym_t("v_out_in1" ,vseq(1)) + + 
self.v_in_iho = sym_t("v_in_iho" ,vseq(1)) + self.v_in_iwo = sym_t("v_in_iwo" ,vseq(1)) + self.v_in_ihi = sym_t("v_in_ihi" ,vseq(1)) + self.v_in_iwi = sym_t("v_in_iwi" ,vseq(1)) + if outer.tunable.nxe != 0: + self.v_in_iy = sym_t("v_in_iy" ,vseq(1)) + self.v_in_ix = sym_t("v_in_ix" ,vseq(1)) + + self.v_move_slice_k_ic1 = sym_t("v_move_slice_k_ic1" , self.v_gtc_ta_ic.value) + if outer.tunable.nxe != 0: + self.v_move_slice_k_iy = sym_t("v_move_slice_k_iy", self.v_in_iy.value) + self.v_move_slice_k_ix = sym_t("v_move_slice_k_ix", self.v_in_ix.value) + + self.v_gemm_in = sym_t("v_gemm_in" , vseq(1)) + self.v_gemm_im = sym_t("v_gemm_im" , vseq(1)) + + self.v_out_iho = sym_t("v_out_iho" ,vseq(1)) + self.v_out_iwo = sym_t("v_out_iwo" ,vseq(1)) + self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,vseq(1)) + self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,vseq(1)) + + self.v_cur_k = sym_t("v_cur_k" ,vseq(1)) + + self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) + total_vgpr = vseq() + if outer.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + # if xdlops agpr is larger than vgpr usage, must change vgpr count to agpr + total_vgpr = max(total_vgpr, outer.tunable.num_agpr_accumulate_c) + self.v_end = sym_t("v_end" ,total_vgpr) + + def get_count(self): + return self.v_end.value + + def emit(self): + for k, v in self.__dict__.items(): + if k.startswith('v_'): + self._emit(v.declare()) + + class kernel_agpr_t(mc_base_t): + def __init__(self, mc, outer): + mc_base_t.__init__(self, mc) + assert outer.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS, 'only xdlops can use agpr' + self.outer = outer + aseq = gpr_sequencer_t() + self.a_c = sym_t("a_c", aseq(outer.tunable.num_agpr_accumulate_c)) + self.a_end = sym_t("a_end", aseq()) + + def get_count(self): + return self.a_end.value + + def emit(self): + for k, v in self.__dict__.items(): + if k.startswith('a_'): + self._emit(v.declare()) + + def get_thread_lengths(self): + t_ta = self.tunable.tensor_a_thread_lengths + t_tb = self.tunable.tensor_b_thread_lengths + + assert len(t_ta) == 4 and len(t_tb) == 3 + + ta_e, ta_c, ta_n0, ta_n1b = t_ta[0], t_ta[1], t_ta[2], t_ta[3] + tb_e, tb_c, tb_k = t_tb[0], t_tb[1], t_tb[2] + + assert ta_e == tb_e and ta_c == tb_c + + assert ta_e == 1, "currently not support >1 in e dimension" + + if self.tunable.nxe == 0: + #assert ta_c0 == 1 + #assert tb_c0 == 1 + pass + else: + pass + + return ta_n0, ta_n1b, ta_e, ta_c, tb_k # M, N, K + + def get_cluster_lengths(self): + c_ta = self.tunable.tensor_a_cluster_lengths + c_tb = self.tunable.tensor_b_cluster_lengths + + assert len(c_ta) == 4 and len(c_tb) == 3 + + ca_e, ca_c, ca_n0, ca_n1b = c_ta[0], c_ta[1], c_ta[2], c_ta[3] + cb_e, cb_c, cb_k = c_tb[0], c_tb[1], c_tb[2] + + assert ca_e == cb_e and ca_c == cb_c + + assert ca_e == 1 and ca_n0 == 1 + + return ca_n0, ca_n1b, ca_e, ca_c, cb_k # M, N, K + + def get_dims_lengths(self): + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + ca_n0, ca_n1b, ca_e, ca_c, cb_k = self.get_cluster_lengths() + + na_n0, na_n1b, na_e, na_c = ta_n0 * ca_n0, ta_n1b * ca_n1b, ta_e * ca_e, ta_c * ca_c + nb_k = tb_k * cb_k + + return na_n0, na_n1b, na_e, na_c, nb_k # M, N, K + + def get_thread_copy_dims(self): + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + in_thread_copy_dims = [ta_n0, ta_n1b, ta_e, ta_c] + wei_thread_copy_dims = [tb_k, ta_e, ta_c] # always reordered! 
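+        # e.g. the 256x128 gfx908 config above gives in_thread_copy_dims = [4, 1, 1, 4]
+        # (non-1 dims at index 0 and 3, so in_thread_copy_ndim == 2) and
+        # wei_thread_copy_dims = [2, 1, 4] (non-1 dims at index 0 and 2); figures are
+        # illustrative only, taken from that config.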
+ return in_thread_copy_dims, wei_thread_copy_dims + + def get_thread_copy_index(self): + in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() + in_thread_copy_index = _find_non_1_index_in_list(in_thread_copy_dims) + wei_thread_copy_index = _find_non_1_index_in_list(wei_thread_copy_dims) + + ''' + if thread lengths both dimension is 1, means every thread only copy one pixel. + we need support this also + ''' + return in_thread_copy_index, wei_thread_copy_index + + def get_macro_global_load(self): + ''' + NOTICE: input/wei always load gemm_k (e*c) first. indeed always load c, and do vector load if possible + ''' + inline = True if self.tunable.fma_interleave else False + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + + in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() + in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + ctrl_wei_gld = ctrl_2d_global_load_t() + ctrl_in_gld = ctrl_2d_global_load_t() + + ctrl_wei_gld.vector_d1 = utility_gcd(ta_c, 4) if ta_c != 1 else 1 + ctrl_in_gld.vector_d1 = utility_gcd(ta_c, 4) if ta_c != 1 else 1 + + if self.wei_thread_copy_ndim == 2: + ctrl_wei_gld.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[0]] + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[1]] + elif self.wei_thread_copy_ndim == 1: + ctrl_wei_gld.length_d0 = 1 + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] + else: + ctrl_wei_gld.length_d0 = 1 + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[-1] + + if self.in_thread_copy_ndim == 2: + ctrl_in_gld.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] + ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[1]] + elif self.in_thread_copy_ndim == 1: + ctrl_in_gld.length_d0 = 1 + ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] + else: + ctrl_in_gld.length_d0 = 1 + ctrl_in_gld.length_d1 = in_thread_copy_dims[-1] + + if self.tunable.precache_soffset: + return macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_wei_gld, inline), \ + macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_in_gld, inline) + else: + return macro_igemm_2d_global_load_t(self.mc, ctrl_wei_gld, inline), macro_igemm_2d_global_load_t(self.mc, ctrl_in_gld, inline) + + + def get_macro_shared_store(self): + in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() + in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + + ## give the LDS strides of wei dimensions [ta_k0, ta_k1, ta_c0, ta_c1e] + #if gemm_m_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_M_K0_K1: + # wei_stride_list = [na_k1, 1, na_c1e*na_k0*na_k1, na_k0*na_k1] + #else: + # wei_stride_list = [1, na_k0, na_c1e*na_k0*na_k1, na_k0*na_k1] + + ## give the LDS strides of in dimensions [tb_c0, tb_c1e, tb_n0, tb_n1b] + #if gemm_n_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_N_N0_N1B: + # in_stride_list = [nb_c1e*nb_n0*nb_n1b, nb_n0*nb_n1b, nb_n1b, 1] + #else: + # in_stride_list = [nb_c1e*nb_n0*nb_n1b, nb_n0*nb_n1b, 1, nb_n0] + + # [ta_n0, ta_n1b, ta_e, ta_c] + if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B: + in_stride_list = [na_n1b, 1, na_c*na_n0*na_n1b, na_n0*na_n1b] + else: + 
in_stride_list = [1, na_n0, na_c*na_n0*na_n1b, na_n0*na_n1b] + + # [tb_k, ta_e, ta_c] + wei_stride_list = [1, nb_k*na_c, nb_k] + + in_sst_ctrl = ctrl_2d_shared_store_t() + in_sst_ctrl.src_order = 1 + in_sst_ctrl.v_tmp = self.vgpr.v_tmp + + wei_sst_ctrl = ctrl_2d_shared_store_t() + wei_sst_ctrl.src_order = 1 + wei_sst_ctrl.v_tmp = self.vgpr.v_tmp + + # [ta_n0, ta_n1b, ta_e, ta_c] + if self.in_thread_copy_ndim == 2: + if in_thread_copy_index[0] in (0, 1) and in_thread_copy_index[1] in (2, 3): + in_sst_ctrl.length_d0 = in_thread_copy_dims[in_thread_copy_index[1]] + in_sst_ctrl.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] + in_sst_ctrl.stride_d0 = in_stride_list[in_thread_copy_index[1]] * data_byte + in_sst_ctrl.stride_d1 = in_stride_list[in_thread_copy_index[0]] * data_byte + else: + in_sst_ctrl.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] + in_sst_ctrl.length_d1 = in_thread_copy_dims[in_thread_copy_index[1]] + in_sst_ctrl.stride_d0 = in_stride_list[in_thread_copy_index[0]] * data_byte + in_sst_ctrl.stride_d1 = in_stride_list[in_thread_copy_index[1]] * data_byte + if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B: + in_sst_ctrl.vector_d1 = ta_n1b + else: + in_sst_ctrl.vector_d1 = in_thread_copy_dims[in_thread_copy_index[1]] + + elif self.in_thread_copy_ndim == 1: + in_sst_ctrl.length_d0 = 1 + in_sst_ctrl.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] + if (gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B and ta_n1b != 1) or \ + (gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0 and ta_n0 != 1): + in_sst_ctrl.vector_d1 = in_thread_copy_dims[in_thread_copy_index[0]] + else: + in_sst_ctrl.vector_d1 = 1 + in_sst_ctrl.stride_d0 = 1 + in_sst_ctrl.stride_d1 = in_stride_list[in_thread_copy_index[0]] * data_byte + if in_sst_ctrl.length_d1 == 8 and in_sst_ctrl.vector_d1 != 1: + # assert False + # TODO: this is indeed not optimal. may consider shuffle in the future. + in_sst_ctrl.length_d0 = 2 + in_sst_ctrl.length_d1 = 4 + in_sst_ctrl.vector_d1 = 4 + in_sst_ctrl.stride_d0 = 4 * data_byte + else: + assert False + + # [tb_k, ta_e, ta_c] + if self.wei_thread_copy_ndim == 2: + if wei_thread_copy_index[0] in (0,) and wei_thread_copy_index[1] in (1, 2): + # when store into LDS, reorder back. 
indeed we always wish this pattern, if ndim is 2 + wei_sst_ctrl.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[1]] + wei_sst_ctrl.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] + wei_sst_ctrl.stride_d0 = wei_stride_list[wei_thread_copy_index[1]] * data_byte + wei_sst_ctrl.stride_d1 = wei_stride_list[wei_thread_copy_index[0]] * data_byte + else: + wei_sst_ctrl.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[0]] + wei_sst_ctrl.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[1]] + wei_sst_ctrl.stride_d0 = wei_stride_list[wei_thread_copy_index[0]] * data_byte + wei_sst_ctrl.stride_d1 = wei_stride_list[wei_thread_copy_index[1]] * data_byte + wei_sst_ctrl.need_transpose = 0 + wei_sst_ctrl.vector_d1 = tb_k + + elif self.wei_thread_copy_ndim == 1: + wei_sst_ctrl.length_d0 = 1 + wei_sst_ctrl.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] + + if ta_c != 1: + wei_sst_ctrl.vector_d1 = utility_gcd(wei_thread_copy_dims[wei_thread_copy_index[0]], 4) + else: + wei_sst_ctrl.vector_d1 = 1 + + wei_sst_ctrl.stride_d0 = 1 + wei_sst_ctrl.stride_d1 = wei_stride_list[wei_thread_copy_index[0]] * data_byte + if wei_sst_ctrl.length_d1 == 8 and wei_sst_ctrl.vector_d1 != 1: + # assert False + # TODO: this is indeed not optimal. may consider shuffle in the future. + wei_sst_ctrl.length_d0 = 2 + wei_sst_ctrl.length_d1 = 4 + wei_sst_ctrl.vector_d1 = 4 + wei_sst_ctrl.stride_d0 = 4 * data_byte + else: + assert False + + # print(f"in_sst_ctrl.vector_d1:{in_sst_ctrl.vector_d1}, wei_sst_ctrl.vector_d1:{wei_sst_ctrl.vector_d1}") + # print(f"wei_sst_ctrl, {wei_sst_ctrl.serialize()}") + inline = True if self.tunable.fma_interleave else False + return macro_igemm_2d_shared_store_t(self.mc, in_sst_ctrl, inline), macro_igemm_2d_shared_store_t(self.mc, wei_sst_ctrl, inline) + + # computation macro + def get_macro_in_update_hw(self): + inline = True if self.tunable.fma_interleave else False + return self.macro_in_update_hw_t(self.mc, inline) + + def get_macro_in_update_os(self): + inline = True if self.tunable.fma_interleave else False + return self.macro_in_update_os_t(self.mc, amdgpu_precision_data_byte(self.tunable.precision), inline) + + def get_macro_move_slice_window(self): + inline = True if self.tunable.fma_interleave else False + move_slice_window = self.macro_move_slice_window_k_e1_c_t(self.mc, self.tunable, inline) + + # return single functor ! 
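+        # roughly, per gemm_k step the emitted macro behaves like the following scalar
+        # sketch (buffer offset updates omitted; see macro_move_slice_window_k_e1_c_t above):
+        #     ic += move_slice_k_c
+        #     if ic >= c:                # carry from c into x
+        #         ic -= c
+        #         ix += 1
+        #         iwi += dilation_w
+        #     if ix >= x:                # carry from x into y
+        #         iwi -= in_diff_sub_wi  # step iwi back by the width covered by one y advance
+        #         ihi += dilation_h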
+ return move_slice_window + + + def get_macro_set_flag_hw(self): + inline = True if self.tunable.fma_interleave else False + return self.macro_set_flag_hw(self.mc, inline) + + def get_symbol_global_load_s_stride_d0_d1(self): + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + # get the symbol object that load 2d may use + s = self.sgpr + s_dummy = sym_t("s_dummy") + in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + + # [ta_n0, ta_n1b, ta_e, ta_c] + in_stride_gprs = [s.s_in_stride_n0 if ta_n0 != 1 else s_dummy, + s.s_in_stride_wi, + s_dummy, + s.s_in_stride_c] + + # [tb_k, ta_e, ta_c] + wei_stride_gprs = [s.s_wei_stride_k, + s_dummy, + s.s_wei_stride_c] + + if self.in_thread_copy_ndim == 2: + s_in_stride_d0 = in_stride_gprs[in_thread_copy_index[0]] + s_in_stride_d1 = in_stride_gprs[in_thread_copy_index[1]] + elif self.in_thread_copy_ndim == 1: + s_in_stride_d0 = s_dummy + s_in_stride_d1 = in_stride_gprs[in_thread_copy_index[0]] + else: + s_in_stride_d0 = s_dummy + s_in_stride_d1 = in_stride_gprs[-1] + + if self.wei_thread_copy_ndim == 2: + # print(f" ____ wei_thread_copy_index:{len(wei_thread_copy_index)}, {wei_thread_copy_index}") + s_wei_stride_d0 = wei_stride_gprs[wei_thread_copy_index[0]] + s_wei_stride_d1 = wei_stride_gprs[wei_thread_copy_index[1]] + elif self.wei_thread_copy_ndim == 1: + s_wei_stride_d0 = s_dummy + s_wei_stride_d1 = wei_stride_gprs[wei_thread_copy_index[0]] + else: + s_wei_stride_d0 = s_dummy + s_wei_stride_d1 = wei_stride_gprs[-1] + + return s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 + + + def get_kernel_code(self): + kernel_code = amdgpu_kernel_code_t({ + 'enable_sgpr_kernarg_segment_ptr' : 1, + 'enable_sgpr_workgroup_id_x' : 1, + 'enable_vgpr_workitem_id' : 0, + 'workgroup_group_segment_byte_size' : self.tunable.lds_total, + 'kernarg_segment_byte_size' : self.karg.get_count(), + 'wavefront_sgpr_count' : self.sgpr.get_count() + 2*3, + 'workitem_vgpr_count' : self.vgpr.get_count() + }) + return kernel_code + + def get_kernel_args(self): + ''' + float *p_in; + float *p_wei; + float *p_out; + int hi; + int wi; + int n; + int k; + int c; + int ho; + int wo; + int stride_h; + int stride_w; + int dilation_h; + int dilation_w; + int pad_h; + int pad_w; + int y; + int x; + int group; + /* if use magic division */ + uint32_t magic_0; // denom: sa=0: n*b / n_per_block, sa=1: k / m_per_block + uint32_t magic_1; // denom: ((n / nb_n0) * b) / nb_n1b + uint32_t magic_2; // denom: y*x, if nxe==0 not used + uint32_t magic_3; // denom: x, if nxe==0 not used + uint32_t magic_4; // denom: b + uint32_t magic_5; // denom: wo + uint32_t magic_6; // denom: n*b*k / (m_per_block*n_per_block) + uint32_t shift_pack_0; + uint32_t shift_pack_1; + uint32_t __pack_0; + ''' + kas = [] + # name: {}, .size: {}, .offset: {}, .value_kind: {}, .value_type + kas.append(amdgpu_kernel_arg_t('p_in' , 8, 0, 'global_buffer','f32',address_space='global',is_const='true')) + kas.append(amdgpu_kernel_arg_t('p_wei' , 8, 8, 'global_buffer','f32',address_space='global',is_const='true')) + kas.append(amdgpu_kernel_arg_t('p_out' , 8, 16, 'global_buffer','f32',address_space='global',is_const='false')) + kas.append(amdgpu_kernel_arg_t('hi' , 4, 24, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('wi' , 4, 28, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('n' , 4, 32, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('k' , 4, 36, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('c' , 4, 40, 'by_value','i32')) + 
kas.append(amdgpu_kernel_arg_t('ho' , 4, 44, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('wo' , 4, 48, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('stride_h' , 4, 52, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('stride_w' , 4, 56, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('dilation_h' , 4, 60, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('dilation_w' , 4, 64, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('pad_h' , 4, 68, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('pad_w' , 4, 72, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('y' , 4, 76, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('x' , 4, 80, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('group' , 4, 84, 'by_value','i32')) + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + kas.append(amdgpu_kernel_arg_t('magic_0' , 4, 88, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_1' , 4, 92, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_2' , 4, 96, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_3' , 4, 100, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_4' , 4, 104, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_5' , 4, 108, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('magic_6' , 4, 112, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('shift_pack_0' , 4, 116, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('shift_pack_1' , 4, 120, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('__pack_0' , 4, 124, 'by_value','i32')) + else: + pass + return kas + + def get_kernel_info(self): + kernel_code = self.get_kernel_code() + kernel_args = self.get_kernel_args() + kernel_info = amdgpu_kernel_info_t(kernel_code, self.name(), self.tunable.block_size, kernel_args) + return kernel_info + + def get_kernel_macros(self): + kernel_macros = [] + for attrs in dir(self): + if attrs.startswith('get_macro_'): + functor = getattr(self, attrs) + rtn = functor() + if rtn is None: + continue + + # here we follow the convention in code: + # #1. for macro like emit class, use emit() to generate macro definition, use __call__() to call this macro + # #2. for non-macro like emit class, which might want to "inline-ed" into normal code, no emit() is defined, just __call__(). + # hence need to check if has attr name "emit". if not have, it is type #2, no need to do emit() before hand. 
+ if type(rtn) is tuple: + for e in rtn: + #if hasattr(e, 'emit'): + if not e.is_inline(): + #continue + kernel_macros.extend([m for m in rtn]) + else: + #if hasattr(rtn, 'emit'): + if not e.is_inline(): + #continue + kernel_macros.append(rtn) + return kernel_macros + + def emit_kernel_prologue(self): + s = self.sgpr + v = self.vgpr + k = self.karg + gemm_m_unmerge_cluster = self.tunable.gemm_m_unmerge_cluster + gemm_n_unmerge_cluster = self.tunable.gemm_n_unmerge_cluster + gemm_k_unmerge_cluster = self.tunable.gemm_k_unmerge_cluster + + assert gemm_m_unmerge_cluster == 0 and gemm_n_unmerge_cluster == 0 and gemm_k_unmerge_cluster == 0, 'in fwd nhwc, gemm_m/n/k unmerge_cluster only support 0' + + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + ca_n0, ca_n1b, ca_e, ca_c, cb_k = self.get_cluster_lengths() + na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + + unmerge_sub_n = self.tunable.unmerge_sub_n + if gemm_n_unmerge_cluster == 0: + assert unmerge_sub_n % na_n0 == 0, f"unmerge_sub_n:{unmerge_sub_n}, na_n0:{na_n0}" + unmerge_sub_n1 = unmerge_sub_n // na_n0 + assert na_n1b % unmerge_sub_n1 == 0, f"na_n1b:{na_n1b}, unmerge_sub_n1:{unmerge_sub_n1}" + + else: + assert False, f"unsupported gemm_n_unmerge_cluster:{self.tunable.gemm_n_unmerge_cluster}" + + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + m_in_update_hw = self.get_macro_in_update_hw() + m_in_update_os = self.get_macro_in_update_os() + # m_wei_update_os = self.get_macro_wei_update_os() + # m_wei_update_yx = self.get_macro_wei_update_yx() + m_set_flag_hw = self.get_macro_set_flag_hw() + s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.get_symbol_global_load_s_stride_d0_d1() + + m_wei_2d_global_load, m_in_2d_global_load = self.get_macro_global_load() + + tc_index_dispatcher = igemm_thread_cluster_index_dispatcher_t(self.mc) + tc_index_accumulator = igemm_thread_cluster_index_accumulator_t(self.mc) + + m_int_div_rem_vv = macro_int_div_rem_vv_t(self.mc) + m_int_div_rem_vs = macro_int_div_rem_vs_t(self.mc) + m_int_div_rem_ss = macro_int_div_rem_ss_t(self.mc) + + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + m_mdiv_u32_vs = macro_mdiv_u32_rem_vs_t(self.mc) + m_mdiv_u32_ss = macro_mdiv_u32_rem_ss_t(self.mc) + else: + m_int_div_rem_vv = macro_int_div_rem_vv_t(self.mc) + m_int_div_rem_vs = macro_int_div_rem_vs_t(self.mc) + m_int_div_rem_ss = macro_int_div_rem_ss_t(self.mc) + + gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + s_dummy = sym_t("s_dummy") + + # start emit + self._emit(f"s_load_dwordx2 s[{s.s_p_in((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_in()}") + self._emit(f"s_load_dwordx2 s[{s.s_p_wei((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_wei()}") + self._emit(f"s_load_dwordx2 s[{s.s_p_out((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_out()}") + if self.tunable.nxe != 0: + self._emit(f"s_load_dwordx8 s[{s.s_hi((0, 7))}], s[{s.s_ka((0, 1))}], 0+{k.k_hi()}") + self._emit(f"s_load_dwordx8 s[{s.s_stride_w((0, 7))}], s[{s.s_ka((0, 1))}], 0+{k.k_stride_w()}") + else: + self._emit(f"s_load_dwordx4 s[{s.s_hi((0, 3))}], s[{s.s_ka((0, 1))}], 0+{k.k_hi()}") + self._emit(f"s_load_dword s[{s.s_c()}], s[{s.s_ka((0, 1))}], 0+{k.k_c()}") + self._emit(f"s_load_dword s[{s.s_group()}], s[{s.s_ka((0, 1))}], 0+{k.k_group()}") + + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_load_dwordx2 s[{s.s_magic_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_0()}") + self._emit(f"s_load_dwordx2 s[{s.s_tmp((2, 3))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_2()}") + self._emit(f"s_load_dwordx2 s[{s.s_tmp((4, 
5))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_4()}") + self._emit(f"s_load_dword s[{s.s_magic_6()}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_6()}") + self._emit(f"s_load_dwordx2 s[{s.s_shift_pack_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") + + self._emit(f"; in(e, c, n0, n1b) thread_lengths: {ta_e}x{ta_c}x{ta_n0}x{ta_n1b}, cluster_length: {ca_e}x{ca_c}x{ca_n0}x{ca_n1b}") + self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") + self._emit(tc_index_dispatcher(v.v_gtc_ta_ic(), v.v_tmp(), ca_c, ta_c)) + if ca_n0 != 1: + # TODO: this is not wanted + self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_n1b, ta_n1b)) + self._emit(tc_index_dispatcher(v.v_gtc_ta_in0(), v.v_tmp(), ca_n0, ta_n0, True)) + else: + self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_n1b, ta_n1b, True)) + + self._emit(f"; wei(e, c, k) thread_length: {ta_e}x{ta_c}x{tb_k}, cluster_length: {ca_e}x{ca_c}x{cb_k}") + # weight ic same as input + self._emit(f"v_lshrrev_b32 v[{v.v_tmp()}], {igemm_log2(ca_c)}, v0") + self._emit(tc_index_dispatcher(v.v_gtc_tb_ik(), v.v_tmp(), cb_k, tb_k, True)) + self._emit_empty_line() + + self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") + self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") + + self._emit(f"s_waitcnt lgkmcnt(0)") + self._emit_empty_line() + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_mov_b32 s[{s.s_magic_2()}], s[{s.s_tmp(2)}]") + self._emit(f"s_mov_b32 s[{s.s_magic_3()}], s[{s.s_tmp(3)}]") + self._emit(f"s_mov_b32 s[{s.s_magic_4()}], s[{s.s_tmp(4)}]") + self._emit(f"s_mov_b32 s[{s.s_magic_5()}], s[{s.s_tmp(5)}]") + self._emit(f"; calculate index") + + # calculate stride, not shift data byte yet + if self.tunable.nxe != 0: + # input + self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_in_stride_hi()}]") + if ta_n0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_in_stride_n0()}], s[{s.s_in_stride_n()}], {utility_log2(unmerge_sub_n1)}") + # weight + self._emit(f"s_mul_i32 s[{s.s_wei_stride_y()}], s[{s.s_x()}], s[{s.s_c()}]") + self._emit(f"s_mul_i32 s[{s.s_wei_stride_k()}], s[{s.s_wei_stride_y()}], s[{s.s_y()}]") + # output + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_k()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_ho()}], s[{s.s_wo()}], s[{s.s_tmp()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho()}], s[{s.s_out_stride_ho()}]") + if ta_n0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") + + else: + # input + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_c()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_tmp()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_in_stride_hi()}]") + if ta_n0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_in_stride_n0()}], s[{s.s_in_stride_n()}], {utility_log2(unmerge_sub_n1)}") + # weight + self._emit(f"s_mov_b32 s[{s.s_wei_stride_k()}], s[{s.s_c()}]") + # output + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_k()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_ho()}], s[{s.s_wi()}], s[{s.s_tmp()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_hi()}], s[{s.s_out_stride_ho()}]") + if ta_n0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") + + # early init s_knum in case shifted + #if 
self.tunable.nxe != 0: + self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_wei_stride_k()}]") + #else: + # self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_c()}]") + + # warp around the really dim_b length, in case pad + if self.tunable.nxe != 0: + self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_ho()}], s[{s.s_wo()}]") + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.nxb - 1}, s[{s.s_tmp(4)}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}") + else: + self._emit(f"s_mul_i32 s[{s.s_dim_b()}], s[{s.s_hi()}], s[{s.s_wi()}]") # no pad + + # for gemm_m pad + # self._emit_empty_line() + # self._emit(f"; pad k if need") + # self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_m_per_block - 1}, s[{s.s_k()}]") + # self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + # self._emit(f"s_lshl_b32 s[{s.s_k_padded()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + + self._emit_empty_line() + self._emit(f"; gemm_m_per_block:{self.tunable.gemm_m_per_block}, gemm_n_per_block:{self.tunable.gemm_n_per_block}, source_access_order:{self.tunable.source_access_order}") + + # calculate group index TODO: use blockIdx.y as group index + self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_dim_b()}], s[{s.s_n()}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_k()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_mul_i32 s[0], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), s.s_magic_6(), s.s_tmp(3), '0', s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + + # s.s_tmp(4)=> rem, gemm_m, gemm_n, s.s_block_gtc_ig()=> quo, group + self._emit(f"s_mov_b32 s[{s.s_bx()}], s[{s.s_tmp(4)}]") + + if self.tunable.source_access_order == IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N: + self._emit(f"s_lshr_b32 s[0], s[{s.s_k()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + + else: + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_dim_b()}], s[{s.s_n()}]") + + self._emit(f"s_lshr_b32 s[0], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(5), s.s_tmp(4), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_tmp(5), s.s_tmp(4), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + + self._emit(f"; s_tmp+4:block_gtc_in, s_tmp+5:block_gtc_im") + + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ik()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_n_per_block)}") + + if unmerge_sub_n1 == 1: + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_n1b)} ; total number 
of n1b") + else: + if unmerge_sub_n1 == na_n1b: + self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of n1b") + else: + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_n1b // unmerge_sub_n1)} ; total number of n1b") + + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(4), s.s_magic_1(), s.s_tmp(3), '0', s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(4), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + + if na_n1b != 1: + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in1b()}], s[{s.s_block_gtc_in1b()}], {igemm_log2(na_n1b)}") + if na_n0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in0()}], s[{s.s_block_gtc_in0()}], {igemm_log2(na_n0)}") + self._emit_empty_line() + + self._emit(f"; in n1b transform") + if ca_n1b == 1: + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}]") + else: + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_gtc_ta_in1b()}]") + if self.tunable.nxe != 0: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + self._emit(f"v_mul_lo_u32 v[{v.v_in_iho()}], s[{s.s_stride_h()}], v[{v.v_in_iho()}]") + self._emit(f"v_sub_i32 v[{v.v_in_iho()}], v[{v.v_in_iho()}], s[{s.s_pad_h()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_in_iwo()}], s[{s.s_stride_w()}], v[{v.v_in_iwo()}]") + self._emit(f"v_sub_i32 v[{v.v_in_iwo()}], v[{v.v_in_iwo()}], s[{s.s_pad_w()}]") + self._emit(m_in_update_hw(v.v_in_ihi(), v.v_in_iwi(), v.v_in_iho(), v.v_in_iwo(), v.v_in_iy(), v.v_in_ix(), s.s_dilation_h(), s.s_dilation_w())) + self._emit_empty_line() + else: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi(), v.v_in_ihi(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi(), v.v_in_ihi(), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) + + self._emit(f"; calculate in offset") + # compute group distance + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ig()}], s[{s.s_block_gtc_ig()}], {igemm_log2(data_byte)}") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], 
s[{s.s_tmp(1)}]") + + self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_in0()}], {igemm_log2(unmerge_sub_n1 * data_byte)}") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_in_stride_n()}], s[{s.s_tmp(3)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_in_stride_n()}], s[{s.s_tmp(3)}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + + self._emit_empty_line() + + #if gemm_m_unmerge_cluster == 0: + if ca_n0 != 1: + self._emit(tc_index_accumulator(v.v_tmp(1), v.v_gtc_ta_in0(), v.v_gtc_ta_in1(), ca_n0, ca_n1b, 0, unmerge_sub_n1)) + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_tmp(1)}]") + else: + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_gtc_ta_in1()}]") + # else: + # # no in0 + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_gtc_ta_in1()}]") + + # s_in_stride_wi need shift before! + if self.tunable.nxe != 0: + self._emit(f"v_add_lshl_u32 v[{v.v_in_os_base()}], v[{v.v_gtc_ta_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi()}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi()}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_in_os_base()}], v[{v.v_tmp()}]") + + self._emit(m_set_flag_hw(v.v_in_flag(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) + else: + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ta_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi()}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi()}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") + self._emit_empty_line() + + if self.in_thread_copy_ndim != 1: + if s_in_stride_d0 != s_dummy: + self._emit(self.try_shift_stride(s_in_stride_d0, igemm_log2(data_byte))) + if s_in_stride_d1 != s_dummy: + self._emit(self.try_shift_stride(s_in_stride_d1, igemm_log2(data_byte))) + self._emit_empty_line() + + if self.tunable.precache_soffset: + self._emit(m_in_2d_global_load.init_precache_soffset(s_in_stride_d0(), s_in_stride_d1(), s.s_in_offset(), s.s_tmp())) + + # load in + self._emit(self.global_load_in()) + self._emit_empty_line() + self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") + # config weight range + self._emit("; config for weight range") + #self._emit(f"s_mul_i32 s[{s.s_p_wei(2)}], s[{s.s_wei_stride_k() if self.tunable.nxe != 0 else s.s_c()}], s[{s.s_k()}]") + #self._emit(f"s_lshl_b32 s[{s.s_p_wei(2)}], s[{s.s_p_wei(2)}], {igemm_log2(data_byte)}") + self._emit(f"s_mov_b32 s[{s.s_p_wei(3)}], 0x27000") + + self._emit(f"; calculate wei offset") + self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_k()}], s[{s.s_wei_stride_k()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") + self._emit(f"s_add_u32 s[{s.s_p_wei()}], s[{s.s_p_wei()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"v_add_u32 v[{v.v_cur_k()}], s[{s.s_block_gtc_ik()}], v[{v.v_gtc_tb_ik()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], 
v[{v.v_cur_k()}]") + + self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(data_byte)}") + + # self._emit(m_wei_update_os(v.v_wei_os(), v.v_wei_os_base(), v.v_wei_iy(), v.v_wei_ix(), s.s_x(), v.v_tmp())) + #else: + # self._emit(tc_index_accumulator(v.v_tmp(), v.v_gtc_ik0(), v.v_gtc_ta_ik1(), ca_k0, ca_k1, na_k0, na_k1)) + # self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_ik()}], v[{v.v_tmp()}]") + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(5)}]") + # self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic1e()}], {igemm_log2(data_byte)}") + + self._emit_empty_line() + if self.wei_thread_copy_ndim != 1: + if s_wei_stride_d0 != s_dummy: + self._emit(self.try_shift_stride(s_wei_stride_d0, igemm_log2(data_byte))) + if s_wei_stride_d1 != s_dummy: + self._emit(self.try_shift_stride(s_wei_stride_d1, igemm_log2(data_byte))) + self._emit_empty_line() + + if self.tunable.precache_soffset: + self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) + + self._emit(self.global_load_wei()) + self._emit_empty_line() + + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") + self._emit(self.thread_mapping(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp())) + else: + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") + self._emit(self.xdlops_mapping.get_gemm_index_for_src_matrix(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp())) + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") + self._emit(self.xdlops_mapping.get_gemm_index_for_dst_matrix(v.v_co_sst(), v.v_co_sld(), v.v_tmp(5), v.v_tmp())) + + self._emit(f"; LDS store, in: e,c,n0,n1b: {ta_e}x{ta_c}x{ta_n0}x{ta_n1b}, {ca_e}x{ca_c}x{ca_n0}x{ca_n1b}") + if ca_n1b == 1: + # TODO: remove this path, not possible go here + assert False + else: + if ca_n0 == 1: + self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in1b()}]") + else: + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_n1b)}, v[{v.v_gtc_ta_in1b()}]") + + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(na_n0*na_n1b)}, v[{v.v_tmp()}]") + #if cb_c0 != 1: + # self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_tb_ic0()}], {igemm_log2(nb_c1e*nb_n0*nb_n1b)}, v[{v.v_tmp()}]") + + self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") + # self._emit(f"v_add_u32 v[{v.v_sst_a_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_a_os()}]") + self._emit_empty_line() + + self._emit(f"; LDS store, wei: e,c,k: {ta_e}x{ta_c}x{tb_k}, {ca_e}x{ca_c}x{cb_k}") + + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(nb_k)}, v[{v.v_gtc_tb_ik()}]") + + self._emit(f"v_lshlrev_b32 v[{v.v_sst_b_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_sst_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_b_os()}]") + self._emit_empty_line() + + self._emit(f"; LDS load") + self._emit(f"v_lshlrev_b32 v[{v.v_sld_b_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_in()}]") + self._emit(f"v_lshlrev_b32 v[{v.v_sld_a_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_im()}]") + self._emit(f"v_add_u32 v[{v.v_sld_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sld_b_os()}]") + self._emit_empty_line() + + if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + self._emit(f"v_mov_b32 v[{v.v_gemm_in()}], v[{v.v_co_sst()}]") + self._emit(f"v_mov_b32 v[{v.v_gemm_im()}], 
v[{v.v_co_sld()}]") + self._emit(self.coalescing_store.init_co_lds_offset(v.v_co_sst(), v.v_co_sld(), v.v_gemm_im(), v.v_gemm_in(), '0', v.v_tmp())) + self._emit(self.coalescing_store.init_co_sub_m_index(v.v_co_sub_m_index(), '0', v.v_tmp())) + self._emit(self.coalescing_store.init_co_sub_n_index(v.v_co_sub_n_index(), '0', v.v_tmp())) + self._emit_empty_line() + + self._emit(f"; output offset") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_k()}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_k()}]") + self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_in0()}], {igemm_log2(unmerge_sub_n1 * data_byte)}") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") + self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") + + self._emit_empty_line() + self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_ik()}], {igemm_log2(data_byte)}") + + self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp(3)}]") + self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out()}+1], 0") + self._emit_empty_line() + + self._emit(f"; compute v_co_sub_m_index along n0 x n1b : {na_n0}x{na_n1b}") + if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B: + if na_n1b != 1: + self._emit(f"v_and_b32 v[{v.v_out_in1b()}], {na_n1b - 1}, v[{v.v_co_sub_m_index()}] ; => N1B") + if na_n0 != 1: + self._emit(f"v_lshrrev_b32 v[{v.v_out_in0()}], {igemm_log2(na_n1b)}, v[{v.v_co_sub_m_index()}] ; => N0") + else: + assert na_n0 == self.tunable.block_size + assert False, "un implemented, should rarely be used" + else: + if na_n0 != 1: + self._emit(f"v_and_b32 v[{v.v_out_in0()}], {na_n0 - 1}, v[{v.v_co_sub_m_index()}] ; => N0") + if na_n1b != 1: + self._emit(f"v_lshrrev_b32 v[{v.v_out_in1b()}], {igemm_log2(na_n0)}, v[{v.v_co_sub_m_index()}] ; => N1B") + else: + assert False, "un implemented, should rarely be used" + else: + if na_n1b != 1: + self._emit(f"v_mov_b32 v[{v.v_out_in1b()}], v[{v.v_co_sub_m_index()}] ; => N1B") + else: + assert False, "un implemented, should rarely be used" + + self._emit(f"; compute from n1b") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_out_in1b()}]") + if self.tunable.nxe != 0: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + self._emit_empty_line() + else: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_magic_4(), 
s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) + self._emit_empty_line() + self._emit_empty_line() + self._emit(f"; add in_in0, in_in1") + if na_n0 != 1: + #if gemm_m_unmerge_cluster == 0: + self._emit(f"v_lshl_or_b32 v[{v.v_tmp(1)}], v[{v.v_out_in0()}], {igemm_log2(unmerge_sub_n1)}, v[{v.v_out_in1()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], s[{s.s_out_stride_n()}], v[{v.v_tmp(1)}]") + # else: + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_n()}], v[{v.v_out_in1()}]") + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_out_stride_n0()}], v[{v.v_out_in0()}]") + # self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_tmp()}], v[{v.v_tmp(1)}]") + else: + self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], s[{s.s_out_stride_n()}], v[{v.v_out_in1()}]") + + self._emit(f"; add i_k") + ## gemm_m_unmerge_cluster is always 0 + # if gemm_m_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_M_K0_K1: + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_co_sub_m_index()}]") + # else: + # if na_k0 == 1: + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_co_sub_m_index()}]") + # else: + # if na_k1 == 1: + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_co_sub_m_index()}]") + # else: + # self._emit(f"v_and_b32 v[{v.v_tmp()}], {na_k0 - 1}, v[{v.v_co_sub_m_index()}] ; => k0") + # self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(na_k0)}, v[{v.v_co_sub_m_index()}] ; => k1") + # self._emit(f"v_lshl_or_b32 v[{v.v_tmp(1)}], v[{v.v_tmp()}], {igemm_log2(na_k1)}, v[{v.v_tmp(1)}]") + # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_tmp(1)}]") + + self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_co_sub_n_index()}]") # n, add to k + + self._emit(f"; add ho, wo") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_wo() if self.tunable.nxe != 0 else s.s_wi()}], v[{v.v_out_iho()}]") + self._emit(f"v_add3_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp(1)}], v[{v.v_out_iwo()}]") + self._emit(f"v_lshlrev_b32 v[{v.v_out_os()}], {igemm_log2(data_byte)}, v[{v.v_out_os()}]") + if self.tunable.nxe != 0: + self._emit(m_set_flag_hw(v.v_out_flag(), v.v_out_iho(), v.v_out_iwo(), s.s_ho(), s.s_wo())) + + self._emit(f"; move slice stride") + # assert na_c0 * na_c1e == self.tunable.gemm_k_per_block and nb_c0 * nb_c1e == self.tunable.gemm_k_per_block + + # if self.tunable.nxe != 0: + # #assert na_c0 * na_c1e == nb_c0 * nb_c1e + # self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c1e()}], {na_c0 * na_c1e}") + # if IGEMM_GTC_FEAT_MAGIC_DIVISION: + # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + # self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_magic_2(), s.s_tmp(3), s.s_wei_stride_c(), s.s_tmp())) + # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080018 ; offset:24, width:8") + # self._emit(m_mdiv_u32_ss(s.s_move_slice_k_x(), s.s_move_slice_k_y(), s.s_tmp(4), s.s_magic_3(), s.s_tmp(3), s.s_x(), s.s_tmp())) + # else: + # self._emit(m_int_div_rem_ss(s.s_tmp(4), 
s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_wei_stride_c(), v.v_tmp(4), v.v_tmp(), s.s_tmp())) + # self._emit(m_int_div_rem_ss(s.s_move_slice_k_x(), s.s_move_slice_k_y(), s.s_tmp(4), s.s_x(), v.v_tmp(4), v.v_tmp(), s.s_tmp())) + # else: + # #assert na_c1e == nb_c1e + # #self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c1()}], {nb_c1e}") + # self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c1e()}], {na_c1e}") + # self._emit_empty_line() + + # m_move_slice_window_ta, m_move_slice_window_tb = self.get_macro_move_slice_window() + + # if self.tunable.nxe != 0: + # # assert s.s_out_stride_k.label not in self.dict_shifted_stride and s.s_wei_stride_k.label not in self.dict_shifted_stride + # if s.s_in_stride_c.label not in self.dict_shifted_stride: + # self._emit(m_move_slice_window_tb.init_stride_c(s.s_in_stride_c(), s.s_in_stride_c_c1(), + # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1())) + # else: + # self._emit(f"s_lshr_b32 s[{s.s_tmp(3)}], s[{s.s_in_stride_c()}], {utility_log2(data_byte)}") + # self._emit(m_move_slice_window_tb.init_stride_c(s.s_tmp(3), s.s_in_stride_c_c1(), + # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1())) + # else: + # if self.is_1d_move_slice_k(): + # self._emit(m_move_slice_window_tb.init_stride_c(s.s_stride_hw(), s.s_in_stride_c_c1(), s.s_move_slice_k_c1e())) + # else: + # self._emit(m_move_slice_window_tb.init_stride_c(s.s_stride_hw(), s.s_in_stride_c_c1(), + # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1e())) + + + # if not self.is_1d_move_slice_k(): + # self._emit(f"s_mov_b32 s[{s.s_gemm_k_num_c1()}], {unmerge_sub_tb_c1}") + #if self.tunable.nxe != 0: + # self._emit(f"s_mul_i32 s[{s.s_knum()}], s[{s.s_wei_stride_c()}], s[{s.s_c()}]") + #else: + # self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_c()}]") + self._emit_empty_line() + + self._emit(self.try_shift_stride(s.s_in_stride_c_c1, igemm_log2(data_byte))) + #self._emit(self.try_shift_stride(s.s_wei_stride_k_k1, igemm_log2(data_byte))) + self._emit(self.try_shift_stride(s.s_in_stride_c_c0_c1_diff, igemm_log2(data_byte))) + #self._emit(self.try_shift_stride(s.s_wei_stride_k_k0_k1_diff, igemm_log2(data_byte))) + + if self.tunable.nxe != 0: + self._emit(self.try_shift_stride(s.s_in_stride_c, igemm_log2(data_byte))) + self._emit(self.try_shift_stride(s.s_wei_stride_k, igemm_log2(data_byte))) + # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) + else: + self._emit(self.try_shift_stride(s.s_in_stride_c, igemm_log2(data_byte))) + self._emit(self.try_shift_stride(s.s_c, igemm_log2(data_byte))) + # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) + + # self._emit(self.try_shift_stride(s.s_move_slice_k_c1e, igemm_log2(data_byte))) + self._emit(f"s_mov_b32 s[{s.s_p_out(2)}], 0xffffffff") + self._emit(f"s_mov_b32 s[{s.s_p_out(3)}], 0x27000") + + + def emit_kernel_fma_main_loop(self): + s = self.sgpr + v = self.vgpr + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + + # m_move_slice_window_ta, m_move_slice_window_tb = self.get_macro_move_slice_window() + m_move_slice_window = self.get_macro_move_slice_window() + + def move_slice_window_b(): + return '' + + # if self.tunable.nxe != 0: + # m_in_update_os = self.get_macro_in_update_os() + # m_in_update_hw = self.get_macro_in_update_hw() + # m_set_flag_hw = self.get_macro_set_flag_hw() + # with self._deferred_context(): + # self._emit(m_move_slice_window_tb(v.v_move_slice_k_ic1(), v.v_move_slice_k_iy(), v.v_move_slice_k_ix(), s.s_gemm_k_num_c1(), s.s_gemm_k_num_y(), s.s_gemm_k_num_x(), + # 
s.s_move_slice_k_c1(), s.s_move_slice_k_y(), s.s_move_slice_k_x(), v.v_in_os_base(), + # s.s_in_stride_c(), s.s_in_stride_c_c1(), s.s_in_stride_c_c0_c1_diff())) + # self._emit(m_in_update_hw(v.v_in_ihi(), v.v_in_iwi(), v.v_in_iho(), v.v_in_iwo(), v.v_in_iy(), v.v_in_ix(), s.s_dilation_h(), s.s_dilation_w())) + # self._emit(m_in_update_os(v.v_in_os(), v.v_in_os_base(), v.v_in_ihi(), v.v_in_iwi(), s.s_wi(), v.v_tmp())) + # self._emit(m_set_flag_hw(v.v_in_flag(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) + # return self._get_deferred() + # else: + # with self._deferred_context(): + # if self.is_1d_move_slice_k(): + # self._emit(m_move_slice_window_tb(v.v_in_os(), s.s_move_slice_k_c1e(), s.s_in_stride_c(), s.s_in_stride_c_c1())) + # else: + # self._emit(m_move_slice_window_tb(v.v_in_os(), v.v_move_slice_k_ic1(), s.s_gemm_k_num_c1(), + # s.s_move_slice_k_c1e(), s.s_in_stride_c(), s.s_in_stride_c_c1(), s.s_in_stride_c_c0_c1_diff())) + # return self._get_deferred() + + def move_slice_window_a(): + return '' + + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + fctrl = ctrl_fma_main_loop_t() + fctrl.thread_m = self.tunable.thread_tile_m + fctrl.thread_n = self.tunable.thread_tile_n + fctrl.unroll_k = self.tunable.gemm_k_per_block + fctrl.label_prefix = self.name() + fctrl.gemm_m_repeat = self.tunable.gemm_m_repeat + fctrl.gemm_m_level0_cluster = self.tunable.gemm_m_level0_cluster + fctrl.gemm_m_level1_cluster = self.tunable.gemm_m_level1_cluster + fctrl.gemm_n_repeat = self.tunable.gemm_n_repeat + fctrl.gemm_n_level0_cluster = self.tunable.gemm_n_level0_cluster + fctrl.gemm_n_level1_cluster = self.tunable.gemm_n_level1_cluster + fctrl.lds_single_size = self.tunable.lds_single # in byte, should be power of 2 + fctrl.lds_buffer_num = self.tunable.lds_buffer_num + + # functor + fctrl.global_load_a_functor = self.global_load_wei + fctrl.global_load_b_functor = self.global_load_in + fctrl.shared_store_a_functor = self.shared_store_wei + fctrl.shared_store_b_functor = self.shared_store_in + fctrl.shared_load_a_functor = inst_ds_read_t(self.tunable.thread_sub_tile_m * data_byte) + fctrl.shared_load_b_functor = inst_ds_read_t(self.tunable.thread_sub_tile_n * data_byte) + fctrl.move_slice_window_a_functor = move_slice_window_a + fctrl.move_slice_window_b_functor = move_slice_window_b + + # sympol type + fctrl.v_a = v.v_a + fctrl.v_b = v.v_b + fctrl.v_c = v.v_c + fctrl.v_gld_a = v.v_gld_a + fctrl.v_gld_b = v.v_gld_b + fctrl.v_sld_a_os = v.v_sld_a_os + fctrl.v_sld_b_os = v.v_sld_b_os + fctrl.v_sst_a_os = v.v_sst_a_os + fctrl.v_sst_b_os = v.v_sst_b_os + fctrl.s_kitr = s.s_kitr + fctrl.s_knum = s.s_knum + + fma_main_loop = fma_main_loop_t(self.mc, fctrl) + fma_main_loop.emit() + else: + a = self.agpr + fctrl = ctrl_mfma_main_loop_t() + ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile_fp32(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, + self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, + self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, + self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE) + fctrl.cxm = ctrl_xdlops_mapping + fctrl.unroll_k = self.tunable.gemm_k_per_block + fctrl.label_prefix = self.name() + fctrl.lds_single_size = self.tunable.lds_single # in byte, should be power of 2 + fctrl.lds_buffer_num = self.tunable.lds_buffer_num + fctrl.local_prefetch_num = self.tunable.local_prefetch_num + fctrl.interleave = self.tunable.fma_interleave + + # functor + 
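+            # note: the mfma main-loop emitter interacts with this kernel only
+            # through these functors and the register symbols assigned below:
+            # the global_load_* functors fill the per-thread staging vgprs
+            # (v_gld_a/v_gld_b), the shared_store_*/shared_load_* functors move
+            # tiles through LDS, and the move_slice_window_* functors advance
+            # the gemm_k window between main-loop iterations.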
fctrl.global_load_a_functor = self.global_load_wei + fctrl.global_load_b_functor = self.global_load_in + fctrl.shared_store_a_functor = self.shared_store_wei + fctrl.shared_store_b_functor = self.shared_store_in + if ctrl_xdlops_mapping.wave_step_m == 1: + fctrl.shared_load_a_functor = inst_ds_read_t(data_byte) # xdlops load from LDS always single load + else: + assert ctrl_xdlops_mapping.wave_step_m == 2, "currently only support wave_step_m is 2" + fctrl.shared_load_a_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ctrl_xdlops_mapping.wave_tile_m * data_byte, sym_t(self.vgpr.v_tmp(4))) + + if ctrl_xdlops_mapping.wave_step_n == 1: + fctrl.shared_load_b_functor = inst_ds_read_t(data_byte) # xdlops load from LDS always single load + else: + assert ctrl_xdlops_mapping.wave_step_n == 2, "currently only support wave_step_n is 2" + fctrl.shared_load_b_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ctrl_xdlops_mapping.wave_tile_n * data_byte, sym_t(self.vgpr.v_tmp(5))) + fctrl.move_slice_window_a_functor = move_slice_window_a + fctrl.move_slice_window_b_functor = move_slice_window_b + + # sympol type + fctrl.v_a = v.v_a + fctrl.v_b = v.v_b + fctrl.a_c = a.a_c + fctrl.v_gld_a = v.v_gld_a + fctrl.v_gld_b = v.v_gld_b + fctrl.v_sld_a_os = v.v_sld_a_os + fctrl.v_sld_b_os = v.v_sld_b_os + fctrl.v_sst_a_os = v.v_sst_a_os + fctrl.v_sst_b_os = v.v_sst_b_os + fctrl.s_kitr = s.s_kitr + fctrl.s_knum = s.s_knum + + mfma_main_loop = mfma_main_loop_t(self.mc, fctrl) + mfma_main_loop.emit() + + + def emit_kernel_epilogue(self): + s = self.sgpr + v = self.vgpr + #label_out = f"L_{self.name()}_out" + + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + # if self.tunable.nxe != 0: + # self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None, + # s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_in_stride_c(), s.s_tmp(), v.v_in_flag())) + # else: + # self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None, + # s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_in_stride_c(), s.s_tmp())) + pass + else: + a = self.agpr + + # self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None, + # None, s.s_out_stride_k(), s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else None, s.s_k(), v.v_cur_k(), s.s_block_gtc_ik(), v.v_co_sub_m_index(), v.v_tmp())) + + self._emit_front(f"{self.label_out}:") + + def emit_kernel_symbol(self): + self.karg.emit() + self._emit_empty_line() + self.sgpr.emit() + self._emit_empty_line() + self.vgpr.emit() + self._emit_empty_line() + if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: + self.agpr.emit() + self._emit_empty_line() + + def emit_kernel_header(self): + kernel_name = self.name() + self._emit('.text') + if self.mc.arch_config.code_object == AMDGPU_CODEOBJECT_V3: + self._emit('.globl {}'.format(kernel_name)) + self._emit('.p2align 8') + if self.mc.arch_config.code_object == AMDGPU_CODEOBJECT_V3: + self._emit('.type {},@function'.format(kernel_name)) + if self.mc.arch_config.code_object == AMDGPU_CODEOBJECT_V2: + self._emit('.amdgpu_hsa_kernel {}'.format(kernel_name)) + self._emit('{}:'.format(kernel_name)) + + def emit_kernel_body(self): + self.emit_kernel_prologue() + self.emit_kernel_fma_main_loop() + self.emit_kernel_epilogue() + def emit_kernel_end(self): + self._emit('s_endpgm') + def emit_kernel_footer(self): + 
self._emit_empty_line() + + def emit_kernel_amd_kernel_code_t(self): + amd_kernel_code_t(self.mc, self.get_kernel_info()).emit() diff --git a/igemm/codegen/compile.py b/igemm/codegen/compile.py index 5bf51dda..88a621a9 100644 --- a/igemm/codegen/compile.py +++ b/igemm/codegen/compile.py @@ -31,7 +31,7 @@ IGEMM_HOST_USE_GPU_NAIVE_CONV = True IGEMM_HOST_USE_XDNN = False -IGEMM_HOST_USE_MAGIC_DIV = True +IGEMM_HOST_USE_MAGIC_DIV = False IGEMM_HOST_USE_HIPCC = True # hipclang perfer use hipcc to compile host code def _check_hip_clang(): diff --git a/igemm/igemm_codegen_driver.py b/igemm/igemm_codegen_driver.py index 142d5078..0a9fce5f 100755 --- a/igemm/igemm_codegen_driver.py +++ b/igemm/igemm_codegen_driver.py @@ -49,7 +49,10 @@ def __init__(self, mc, tunable_dicts): for tdd in tunable_dicts: assert tdd['direction'] == 'fwd' # gtc fwd - kernel_list.extend([igemm_fwd_gtc_t(mc_asm_printer_t(mc.emitter, mc.arch_config), igemm_gtc_tunable_parameter_t(td)) for td in tunable_dicts]) + if 'tensor_layout' in tunable_dicts[0] and tunable_dicts[0]['tensor_layout'] == 'nhwc': + kernel_list.extend([igemm_fwd_gtc_nhwc_t(mc_asm_printer_t(mc.emitter, mc.arch_config), igemm_gtc_tunable_parameter_t(td)) for td in tunable_dicts]) + else: + kernel_list.extend([igemm_fwd_gtc_t(mc_asm_printer_t(mc.emitter, mc.arch_config), igemm_gtc_tunable_parameter_t(td)) for td in tunable_dicts]) elif tunable_dicts[0]['direction'] == 'bwd': for tdd in tunable_dicts: From 4731118fa8d705f92fa4f4ff22a2071bf0765776 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 14 Jan 2021 21:47:52 +0800 Subject: [PATCH 02/40] fixup following part of nhwc --- igemm/algo/igemm_fwd_gtc_nhwc.py | 252 +++++++++++++++++++------------ 1 file changed, 156 insertions(+), 96 deletions(-) diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index a16e1cbe..b420b2d2 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -237,7 +237,7 @@ class macro_move_slice_window_k_e1_c_t(macro_base_t): def __init__(self, mc, tunable, inline = False): macro_base_t.__init__(self, mc, inline) self.tunable = tunable - #self.declare_arg("v_move_slice_k_iy") + self.declare_arg("v_move_slice_k_iy") self.declare_arg("v_move_slice_k_ix") self.declare_arg("v_move_slice_k_ic") #self.declare_arg("s_gemm_k_num_y") @@ -247,10 +247,10 @@ def __init__(self, mc, tunable, inline = False): self.declare_arg("v_in_os") self.declare_arg("v_wei_os") - self.declare_arg("s_in_stride_c") # this is indeed s_move_slice_k_c * data_byte - self.declare_arg("s_in_stride_gemm_k_num_c") - self.declare_arg("s_in_stride_diff_x") # indeed stride_x - stride_c, always possitive - self.declare_arg("s_in_stride_diff_y") # indeed stride_y - stride_x, always possitive + # self.declare_arg("s_in_stride_gemm_k_num_c") + self.declare_arg("s_move_slice_k_in_stride_diff_y") # indeed stride_y - stride_x, always possitive + self.declare_arg("s_move_slice_k_in_stride_diff_x") # indeed stride_x - stride_c, always possitive + self.declare_arg("s_move_slice_k_stride_c") # this is indeed s_move_slice_k_c * data_byte, same for input/weight self.declare_arg("v_in_ihi") # need update self.declare_arg("v_in_iwi") # need update @@ -265,23 +265,43 @@ def name(self): def expr(self): self._emit(f"v_add_u32 v[{self.v_move_slice_k_ic()}], s[{self.s_move_slice_k_c()}], v[{self.v_move_slice_k_ic()}]") - self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_in_stride_c()}], v[{self.v_in_os()}]") - self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_wei_stride_c()}], 
v[{self.v_wei_os()}]") # weight offset always increase, treat y*x*c as single dimension + self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_in_os()}]") + self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") # weight offset always increase, treat y*x*c as single dimension self._emit(f"v_cmpx_le_u32 vcc, s[{self.s_gemm_k_num_c()}], v[{self.v_move_slice_k_ic()}]") self._emit(f"v_subrev_u32 v[{self.v_move_slice_k_ic()}], s[{self.s_gemm_k_num_c()}], v[{self.v_move_slice_k_ic()}]") self._emit(f"v_add_u32 v[{self.v_move_slice_k_ix()}], 1, v[{self.v_move_slice_k_ix()}]") - self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_in_stride_diff_x()}], v[{self.v_in_os()}]") # merge with above c + self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_in_stride_diff_x()}], v[{self.v_in_os()}]") # merge with above c self._emit(f"v_add_u32 v[{self.v_in_iwi()}], s[{self.s_in_diff_wi()}], v[{self.v_in_iwi()}]") self._emit(f"s_mov_b64 exec, -1") self._emit_empty_line() self._emit(f"v_cmpx_le_u32 vcc s[{self.s_gemm_k_num_x()}], v[{self.v_move_slice_k_ix()}]") - self._emit(f"v_add_u32 v[{self.v_in_os_base()}], s[{self.s_in_stride_diff_y()}], v[{self.v_in_os_base()}]") + self._emit(f"v_add_u32 v[{self.v_move_slice_k_iy()}], 1, v[{self.v_move_slice_k_iy()}]") + self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_in_stride_diff_y()}], v[{self.v_in_os()}]") self._emit(f"v_subrev_u32 v[{self.v_in_iwi()}], s[{self.s_in_diff_sub_wi()}], v[{self.v_in_iwi()}]") self._emit(f"v_add_u32 v[{self.v_in_ihi()}], s[{self.s_in_diff_hi()}], v[{self.v_in_ihi()}]") self._emit(f"s_mov_b64 exec, -1") self._emit_empty_line() # free of last dim check + class macro_move_slice_window_k_nxe0_c_t(macro_base_t): + ''' + used for nxe=0. 
only c move is needed + ''' + def __init__(self, mc, tunable, inline = False): + macro_base_t.__init__(self, mc, inline) + self.tunable = tunable + self.declare_arg("v_in_os") + self.declare_arg("v_wei_os") + self.declare_arg("s_move_slice_k_stride_c") # this is indeed s_move_slice_k_c * data_byte + + def name(self): + return '.v_fwd_gtc_nhwc_move_slice_window_k_nxe0_c' + + def expr(self): + self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_in_os()}]") + self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") + self._emit_empty_line() + class global_load_in_t(mc_base_t): def __init__(self, mc, outer): mc_base_t.__init__(self, mc) @@ -468,13 +488,11 @@ def __init__(self, mc, outer): self.s_wei_stride_c = sym_t("s_wei_stride_c" , self.s_stride_c.value) # stride for out - self.s_out_stride_ho = sym_t('s_out_stride_ho' , sseq(1)) + self.s_out_stride_wo = sym_t('s_out_stride_wo' , sseq(1)) self.s_out_stride_n = sym_t('s_out_stride_n' , sseq(1)) if ta_n0 != 1: self.s_out_stride_n0 = sym_t('s_out_stride_n0' , sseq(1)) - - self.s_in_stride_c_c1 = sym_t("s_in_stride_c_c1" , sseq(1)) self.s_in_stride_c_c0_c1_diff = sym_t("s_in_stride_c_c0_c1_diff" , sseq(1)) @@ -485,16 +503,26 @@ def __init__(self, mc, outer): self.s_move_slice_k_c1e = sym_t("s_move_slice_k_c1e" , sseq(1)) if outer.tunable.nxe != 0: - self.s_move_slice_k_c1 = sym_t("s_move_slice_k_c1" , sseq(1)) + self.s_move_slice_k_c = sym_t("s_move_slice_k_c" , sseq(1)) self.s_move_slice_k_y = sym_t("s_move_slice_k_y" , sseq(1)) self.s_move_slice_k_x = sym_t("s_move_slice_k_x" , self.s_block_gtc_ig.value) + self.s_move_slice_k_stride_c = sym_t("s_move_slice_k_stride_c" , sseq(1)) + self.s_in_diff_sub_wi = sym_t("s_in_diff_sub_wi" , sseq(1)) + if outer.tunable.nxe != 0: + self.s_move_slice_k_in_stride_diff_y = sym_t("s_move_slice_k_in_stride_diff_y" , sseq(1)) + self.s_move_slice_k_in_stride_diff_x = sym_t("s_move_slice_k_in_stride_diff_x" , sseq(1)) + + self.s_knum = sym_t("s_knum" , 3) - self.s_gemm_k_num_c1 = sym_t("s_gemm_k_num_c1" , sseq(1)) + self.s_gemm_k_num_c = sym_t("s_gemm_k_num_c" , sseq(1)) if outer.tunable.nxe != 0: self.s_gemm_k_num_y = sym_t("s_gemm_k_num_y" , self.s_y.value) self.s_gemm_k_num_x = sym_t("s_gemm_k_num_x" , self.s_x.value) + # self.s_move_slice_k_in_stride_diff_y = sym_t("s_move_slice_k_in_stride_diff_y" , sseq(1)) + # self.s_move_slice_k_in_stride_diff_x = sym_t("s_move_slice_k_in_stride_diff_x" , sseq(1)) + #if outer.tunable.nxe != 0: self.s_dim_b = sym_t("s_dim_b" , sseq(1)) @@ -546,17 +574,17 @@ def __init__(self, mc, outer): is_vgpr_acc_c = outer.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS vseq = gpr_sequencer_t() if is_vgpr_acc_c: - self.v_c = sym_t("v_c" ,vseq(outer.tunable.num_vgpr_accumulate_c)) - v_c_num = vseq() + self.v_c = sym_t("v_c" ,vseq(outer.tunable.num_vgpr_accumulate_c)) + v_c_num = vseq() else: - v_c_resuable_num = outer.tunable.num_vgpr_accumulate_a + outer.tunable.num_vgpr_accumulate_b + \ - outer.tunable.num_vgpr_global_load_a + outer.tunable.num_vgpr_global_load_b + \ - 16 # from v_sst_a_os to v_co_sst - v_c_coalescing_num = outer.tunable.num_agpr_accumulate_c // outer.coalescing_store_groups - v_c_needed = (v_c_coalescing_num - v_c_resuable_num) if (v_c_coalescing_num - v_c_resuable_num) > 0 else 0 + v_c_resuable_num = outer.tunable.num_vgpr_accumulate_a + outer.tunable.num_vgpr_accumulate_b + \ + outer.tunable.num_vgpr_global_load_a + outer.tunable.num_vgpr_global_load_b + \ + 16 # from 
v_sst_a_os to v_co_sst + v_c_coalescing_num = outer.tunable.num_agpr_accumulate_c // outer.coalescing_store_groups + v_c_needed = (v_c_coalescing_num - v_c_resuable_num) if (v_c_coalescing_num - v_c_resuable_num) > 0 else 0 - v_c_needed = v_c_needed if v_c_needed > 2 else 2 # let at least 2 - self.v_c = sym_t("v_c" ,vseq(v_c_needed), f"coalescing:{v_c_coalescing_num}, needed:{v_c_needed}, resuable:{v_c_resuable_num}") + v_c_needed = v_c_needed if v_c_needed > 2 else 2 # let at least 2 + self.v_c = sym_t("v_c" ,vseq(v_c_needed), f"coalescing:{v_c_coalescing_num}, needed:{v_c_needed}, resuable:{v_c_resuable_num}") self.v_a = sym_t("v_a" ,vseq(outer.tunable.num_vgpr_accumulate_a)) self.v_b = sym_t("v_b" ,vseq(outer.tunable.num_vgpr_accumulate_b)) @@ -583,42 +611,42 @@ def __init__(self, mc, outer): self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) - self.v_out_os = sym_t("v_out_os" ,vseq(1)) + self.v_out_os = sym_t("v_out_os" ,vseq(1)) if outer.tunable.nxe != 0: - self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) - self.v_out_in0 = sym_t("v_out_in0" ,vseq(1)) - self.v_out_in1b = sym_t("v_out_in1b" ,vseq(1)) - self.v_out_in1 = sym_t("v_out_in1" ,vseq(1)) - - self.v_in_iho = sym_t("v_in_iho" ,vseq(1)) - self.v_in_iwo = sym_t("v_in_iwo" ,vseq(1)) - self.v_in_ihi = sym_t("v_in_ihi" ,vseq(1)) - self.v_in_iwi = sym_t("v_in_iwi" ,vseq(1)) + self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) + self.v_out_in0 = sym_t("v_out_in0" ,vseq(1)) + self.v_out_in1b = sym_t("v_out_in1b" ,vseq(1)) + self.v_out_in1 = sym_t("v_out_in1" ,vseq(1)) + + self.v_in_iho = sym_t("v_in_iho" ,vseq(1)) + self.v_in_iwo = sym_t("v_in_iwo" ,vseq(1)) + self.v_in_ihi = sym_t("v_in_ihi" ,vseq(1)) + self.v_in_iwi = sym_t("v_in_iwi" ,vseq(1)) if outer.tunable.nxe != 0: self.v_in_iy = sym_t("v_in_iy" ,vseq(1)) self.v_in_ix = sym_t("v_in_ix" ,vseq(1)) - self.v_move_slice_k_ic1 = sym_t("v_move_slice_k_ic1" , self.v_gtc_ta_ic.value) + self.v_move_slice_k_ic = sym_t("v_move_slice_k_ic1" , self.v_gtc_ta_ic.value) if outer.tunable.nxe != 0: - self.v_move_slice_k_iy = sym_t("v_move_slice_k_iy", self.v_in_iy.value) - self.v_move_slice_k_ix = sym_t("v_move_slice_k_ix", self.v_in_ix.value) + self.v_move_slice_k_iy = sym_t("v_move_slice_k_iy", self.v_in_iy.value) + self.v_move_slice_k_ix = sym_t("v_move_slice_k_ix", self.v_in_ix.value) - self.v_gemm_in = sym_t("v_gemm_in" , vseq(1)) - self.v_gemm_im = sym_t("v_gemm_im" , vseq(1)) + self.v_gemm_in = sym_t("v_gemm_in" , vseq(1)) + self.v_gemm_im = sym_t("v_gemm_im" , vseq(1)) - self.v_out_iho = sym_t("v_out_iho" ,vseq(1)) - self.v_out_iwo = sym_t("v_out_iwo" ,vseq(1)) - self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,vseq(1)) - self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,vseq(1)) + self.v_out_iho = sym_t("v_out_iho" ,vseq(1)) + self.v_out_iwo = sym_t("v_out_iwo" ,vseq(1)) + self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,vseq(1)) + self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,vseq(1)) - self.v_cur_k = sym_t("v_cur_k" ,vseq(1)) + self.v_cur_k = sym_t("v_cur_k" ,vseq(1)) - self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) - total_vgpr = vseq() + self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) + total_vgpr = vseq() if outer.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # if xdlops agpr is larger than vgpr usage, must change vgpr count to agpr - total_vgpr = max(total_vgpr, outer.tunable.num_agpr_accumulate_c) - self.v_end = sym_t("v_end" ,total_vgpr) + total_vgpr = max(total_vgpr, outer.tunable.num_agpr_accumulate_c) + self.v_end = sym_t("v_end" ,total_vgpr) def 
get_count(self): return self.v_end.value @@ -879,12 +907,14 @@ def get_macro_in_update_os(self): def get_macro_move_slice_window(self): inline = True if self.tunable.fma_interleave else False - move_slice_window = self.macro_move_slice_window_k_e1_c_t(self.mc, self.tunable, inline) + if self.tunable.nxe != 0: + move_slice_window = self.macro_move_slice_window_k_e1_c_t(self.mc, self.tunable, inline) + else: + move_slice_window = self.macro_move_slice_window_k_nxe0_c_t(self.mc, self.tunable, inline) # return single functor ! return move_slice_window - def get_macro_set_flag_hw(self): inline = True if self.tunable.fma_interleave else False return self.macro_set_flag_hw(self.mc, inline) @@ -1134,6 +1164,10 @@ def emit_kernel_prologue(self): self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") + if self.tunable.nxe != 0: + self._emit(f"v_mov_b32 v[{v.v_in_iy()}], 0") + self._emit(f"v_mov_b32 v[{v.v_in_ix()}], 0") + self._emit(f"s_waitcnt lgkmcnt(0)") self._emit_empty_line() if IGEMM_GTC_FEAT_MAGIC_DIVISION: @@ -1155,9 +1189,9 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_wei_stride_y()}], s[{s.s_x()}], s[{s.s_c()}]") self._emit(f"s_mul_i32 s[{s.s_wei_stride_k()}], s[{s.s_wei_stride_y()}], s[{s.s_y()}]") # output - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_k()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_out_stride_ho()}], s[{s.s_wo()}], s[{s.s_tmp()}]") - self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho()}], s[{s.s_out_stride_ho()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wo()}], s[{s.s_out_stride_wo()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho()}], s[{s.s_tmp(1)}]") if ta_n0 != 1: self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") @@ -1171,9 +1205,9 @@ def emit_kernel_prologue(self): # weight self._emit(f"s_mov_b32 s[{s.s_wei_stride_k()}], s[{s.s_c()}]") # output - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_k()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_out_stride_ho()}], s[{s.s_wi()}], s[{s.s_tmp()}]") - self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_hi()}], s[{s.s_out_stride_ho()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wi()}], s[{s.s_out_stride_wo()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_hi()}], s[{s.s_tmp(1)}]") if ta_n0 != 1: self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") @@ -1433,6 +1467,10 @@ def emit_kernel_prologue(self): self._emit(self.coalescing_store.init_co_sub_n_index(v.v_co_sub_n_index(), '0', v.v_tmp())) self._emit_empty_line() + ''' + a good news for nhwc and coalescing output is that, we can treat gemm_m (n*ho*wo) as a single dimension, + and use sgpr to stride along this dimension. 
this is much easier + ''' self._emit(f"; output offset") self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_k()}]") self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_k()}]") @@ -1459,20 +1497,20 @@ def emit_kernel_prologue(self): if na_n0 != 1: self._emit(f"v_lshrrev_b32 v[{v.v_out_in0()}], {igemm_log2(na_n1b)}, v[{v.v_co_sub_m_index()}] ; => N0") else: - assert na_n0 == self.tunable.block_size assert False, "un implemented, should rarely be used" else: - if na_n0 != 1: - self._emit(f"v_and_b32 v[{v.v_out_in0()}], {na_n0 - 1}, v[{v.v_co_sub_m_index()}] ; => N0") - if na_n1b != 1: - self._emit(f"v_lshrrev_b32 v[{v.v_out_in1b()}], {igemm_log2(na_n0)}, v[{v.v_co_sub_m_index()}] ; => N1B") - else: - assert False, "un implemented, should rarely be used" - else: - if na_n1b != 1: - self._emit(f"v_mov_b32 v[{v.v_out_in1b()}], v[{v.v_co_sub_m_index()}] ; => N1B") - else: - assert False, "un implemented, should rarely be used" + assert False + # if na_n0 != 1: + # self._emit(f"v_and_b32 v[{v.v_out_in0()}], {na_n0 - 1}, v[{v.v_co_sub_m_index()}] ; => N0") + # if na_n1b != 1: + # self._emit(f"v_lshrrev_b32 v[{v.v_out_in1b()}], {igemm_log2(na_n0)}, v[{v.v_co_sub_m_index()}] ; => N1B") + # else: + # assert False, "un implemented, should rarely be used" + # else: + # if na_n1b != 1: + # self._emit(f"v_mov_b32 v[{v.v_out_in1b()}], v[{v.v_co_sub_m_index()}] ; => N1B") + # else: + # assert False, "un implemented, should rarely be used" self._emit(f"; compute from n1b") self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_out_in1b()}]") @@ -1528,13 +1566,38 @@ def emit_kernel_prologue(self): self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_co_sub_n_index()}]") # n, add to k self._emit(f"; add ho, wo") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_k()}], s[{s.s_group()}] ; stride for wo") self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_wo() if self.tunable.nxe != 0 else s.s_wi()}], v[{v.v_out_iho()}]") - self._emit(f"v_add3_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp(1)}], v[{v.v_out_iwo()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(2)}], v[{v.v_tmp(1)}], v[{v.v_out_iwo()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_tmp()}], v[{v.v_tmp(2)}]") + self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp()}]") self._emit(f"v_lshlrev_b32 v[{v.v_out_os()}], {igemm_log2(data_byte)}, v[{v.v_out_os()}]") if self.tunable.nxe != 0: self._emit(m_set_flag_hw(v.v_out_flag(), v.v_out_iho(), v.v_out_iwo(), s.s_ho(), s.s_wo())) self._emit(f"; move slice stride") + self._emit(f"s_mov_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_c()}]") + if self.tunable.nxe != 0: + self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c()}], {na_c}") + self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_move_slice_k_c()}], {igemm_log2(data_byte)}") + else: + self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_c()}], {na_c * data_byte}") + + if self.tunable.nxe != 0: + self._emit(f"s_lshl_b32 s[{s.s_tmp(2)}], s[{s.s_c()}], {igemm_log2(data_byte)}") + # diff_y, ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_wi()}], s[{s.s_tmp(4)}]") + self._emit(f"s_mul_i32 s[{s.s_move_slice_k_in_stride_diff_y()}], s[{s.s_dilation_h()}], s[{s.s_tmp(4)}]") + self._emit(self.try_shift_stride(s.s_move_slice_k_in_stride_diff_y, igemm_log2(data_byte))) + self._emit(f"s_sub_u32 s[{s.s_move_slice_k_in_stride_diff_y()}], s[{s.s_move_slice_k_in_stride_diff_y()}], s[{s.s_tmp(2)}]") + # diff_x, iwi = iwo * 
s_stride_w + ix * s_dilation_w - s_pad_w, hence need compute s_dilation_w per increase + self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_c()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_move_slice_k_in_stride_diff_x()}], s[{s.s_dilation_w()}], s[{s.s_tmp(4)}]") + self._emit(self.try_shift_stride(s.s_move_slice_k_in_stride_diff_x, igemm_log2(data_byte))) + self._emit(f"s_sub_u32 s[{s.s_move_slice_k_in_stride_diff_x()}], s[{s.s_move_slice_k_in_stride_diff_x()}], s[{s.s_tmp(2)}]") + + self._emit(f"s_mul_i32 s[{s.s_in_diff_sub_wi()}], s[{s.s_x()}], s[{s.s_dilation_w()}] ") + # assert na_c0 * na_c1e == self.tunable.gemm_k_per_block and nb_c0 * nb_c1e == self.tunable.gemm_k_per_block # if self.tunable.nxe != 0: @@ -1581,9 +1644,9 @@ def emit_kernel_prologue(self): # self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_c()}]") self._emit_empty_line() - self._emit(self.try_shift_stride(s.s_in_stride_c_c1, igemm_log2(data_byte))) + #self._emit(self.try_shift_stride(s.s_in_stride_c_c1, igemm_log2(data_byte))) #self._emit(self.try_shift_stride(s.s_wei_stride_k_k1, igemm_log2(data_byte))) - self._emit(self.try_shift_stride(s.s_in_stride_c_c0_c1_diff, igemm_log2(data_byte))) + #self._emit(self.try_shift_stride(s.s_in_stride_c_c0_c1_diff, igemm_log2(data_byte))) #self._emit(self.try_shift_stride(s.s_wei_stride_k_k0_k1_diff, igemm_log2(data_byte))) if self.tunable.nxe != 0: @@ -1607,30 +1670,24 @@ def emit_kernel_fma_main_loop(self): # m_move_slice_window_ta, m_move_slice_window_tb = self.get_macro_move_slice_window() m_move_slice_window = self.get_macro_move_slice_window() + m_set_flag_hw = self.get_macro_set_flag_hw() def move_slice_window_b(): - return '' - - # if self.tunable.nxe != 0: - # m_in_update_os = self.get_macro_in_update_os() - # m_in_update_hw = self.get_macro_in_update_hw() - # m_set_flag_hw = self.get_macro_set_flag_hw() - # with self._deferred_context(): - # self._emit(m_move_slice_window_tb(v.v_move_slice_k_ic1(), v.v_move_slice_k_iy(), v.v_move_slice_k_ix(), s.s_gemm_k_num_c1(), s.s_gemm_k_num_y(), s.s_gemm_k_num_x(), - # s.s_move_slice_k_c1(), s.s_move_slice_k_y(), s.s_move_slice_k_x(), v.v_in_os_base(), - # s.s_in_stride_c(), s.s_in_stride_c_c1(), s.s_in_stride_c_c0_c1_diff())) - # self._emit(m_in_update_hw(v.v_in_ihi(), v.v_in_iwi(), v.v_in_iho(), v.v_in_iwo(), v.v_in_iy(), v.v_in_ix(), s.s_dilation_h(), s.s_dilation_w())) - # self._emit(m_in_update_os(v.v_in_os(), v.v_in_os_base(), v.v_in_ihi(), v.v_in_iwi(), s.s_wi(), v.v_tmp())) - # self._emit(m_set_flag_hw(v.v_in_flag(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) - # return self._get_deferred() - # else: - # with self._deferred_context(): - # if self.is_1d_move_slice_k(): - # self._emit(m_move_slice_window_tb(v.v_in_os(), s.s_move_slice_k_c1e(), s.s_in_stride_c(), s.s_in_stride_c_c1())) - # else: - # self._emit(m_move_slice_window_tb(v.v_in_os(), v.v_move_slice_k_ic1(), s.s_gemm_k_num_c1(), - # s.s_move_slice_k_c1e(), s.s_in_stride_c(), s.s_in_stride_c_c1(), s.s_in_stride_c_c0_c1_diff())) - # return self._get_deferred() + ''' + in nhwc we only need call one move slice window + ''' + if self.tunable.nxe != 0: + with self._deferred_context(): + self._emit(m_move_slice_window(v.v_move_slice_k_iy(), v.v_move_slice_k_ix(), v.v_move_slice_k_ic(), + s.s_gemm_k_num_x(), s.s_gemm_k_num_c(), s.s_move_slice_k_c(), v.v_in_os(), v.v_wei_os(), + s.s_move_slice_k_in_stride_diff_y(), s.s_move_slice_k_in_stride_diff_x(), s.s_move_slice_k_stride_c(), + v.v_in_ihi(), v.v_in_iwi(), s.s_dilation_h(), s.s_dilation_w(), s.s_in_diff_sub_wi())) 
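+                    # the window move may step v_in_ihi/v_in_iwi into the padding
+                    # region, so the per-thread validity flag must be refreshed
+                    # before the next global load predicates on it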
+ self._emit(m_set_flag_hw(v.v_in_flag(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) + return self._get_deferred() + else: + with self._deferred_context(): + self._emit(m_move_slice_window(v.v_in_os(), v.v_wei_os(),s.s_move_slice_k_stride_c())) + return self._get_deferred() def move_slice_window_a(): return '' @@ -1731,6 +1788,9 @@ def emit_kernel_epilogue(self): v = self.vgpr #label_out = f"L_{self.name()}_out" + ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + ca_n0, ca_n1b, ca_e, ca_c, cb_k = self.get_cluster_lengths() + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # if self.tunable.nxe != 0: # self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None, @@ -1738,12 +1798,12 @@ def emit_kernel_epilogue(self): # else: # self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None, # s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_in_stride_c(), s.s_tmp())) - pass + assert False else: a = self.agpr - - # self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None, - # None, s.s_out_stride_k(), s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else None, s.s_k(), v.v_cur_k(), s.s_block_gtc_ik(), v.v_co_sub_m_index(), v.v_tmp())) + self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None, + s.s_out_stride_n0() if ta_n0 != 1 else None, s.s_out_stride_wo(), + s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else None, s.s_k(), v.v_cur_k(), s.s_block_gtc_ik(), v.v_co_sub_m_index(), v.v_tmp())) self._emit_front(f"{self.label_out}:") From 6fa9c74d6ed4c405dd6425b20a322dde2c19b534 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 18 Jan 2021 15:47:50 +0800 Subject: [PATCH 03/40] host build --- driver/args.h | 3 + driver/conv_driver.cpp | 87 ++++++++++++----- driver/igemm_fwd_gtc_driver.h | 176 +++++++++++++++++++++++----------- 3 files changed, 189 insertions(+), 77 deletions(-) diff --git a/driver/args.h b/driver/args.h index e873af82..079ce7eb 100644 --- a/driver/args.h +++ b/driver/args.h @@ -165,6 +165,9 @@ static inline args_t create_conv_args(int argc, char *argv[]) { } args_t args; + args.insert_arg("in_layout", 'I', "NCHW", "Input Layout (Default=NCHW)", "string"); + args.insert_arg("out_layout", 'O', "NCHW", "Output Layout (Default=NCHW)", "string"); + args.insert_arg("fil_layout", 'f', "NCHW", "Input Layout (Default=NCHW)", "string"); args.insert_arg("spatial_dim", '_', "2", "convolution spatial dimension (Default-2)", "int"); args.insert_arg("forw", 'F', "0", "Flag enables fwd, bwd, wrw convolutions" diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index 36803082..f431008d 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -57,18 +57,8 @@ # define IGEMM_GPU_NAIVE_CONV_HSACO "naive_conv.hsaco" # endif #else -# ifdef USE_XDNN -# include "xdnn_conv.h" -# define conv_fwd_nchw xdnn_conv_fwd_nchw -# define conv_bwd_nchw xdnn_conv_bwd_nchw -# define conv_wrw_nchw xdnn_conv_wrw_nchw -# else -# define NAIVE_CONV_THREADED -# include "naive_conv.h" -# define conv_fwd_nchw naive_conv_fwd_nchw -# define conv_bwd_nchw naive_conv_bwd_nchw -# define conv_wrw_nchw naive_conv_wrw_nchw -# endif +# define NAIVE_CONV_THREADED +# include "naive_conv.h" #endif static inline size_t conv_out_size(size_t in_size, size_t pad, size_t dilation, @@ -396,11 +386,19 @@ int main(int argc, char **argv) { int ho = conv_out_size(hi, pad_h, 
dilation_h, y, stride_h); int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); int forw = conv_args.get_int("forw"); + std::string in_layout = conv_args.get_str("in_layout"); + std::string out_layout = conv_args.get_str("out_layout"); + std::string fil_layout = conv_args.get_str("fil_layout"); int need_fwd = (forw == 0 ? 1 : (forw & 1 ? 1 : 0)); int need_bwd = (forw == 0 ? 1 : (forw & 2 ? 1 : 0)); int need_wrw = (forw == 0 ? 1 : (forw & 4 ? 1 : 0)); + assert(in_layout == out_layout && in_layout == fil_layout); // currently only support all layout is the same + assert(in_layout == "NCHW" || in_layout == "NHWC"); // currently only support these layout + assert((in_layout == "NCHW" && tunables[0].tensor_layout == "nchw") || + (in_layout == "NHWC" && tunables[0].tensor_layout == "nhwc")); // check pairs + // init host side float *host_input = (float *)malloc(static_cast(n) * c * hi * wi * sizeof(float)); float *host_weight = (float *)malloc(static_cast(k) * c * y * x * sizeof(float)); @@ -465,19 +463,34 @@ int main(int argc, char **argv) { static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); - - gpu_naive_conv_fwd_nchw_fp32(device_input, device_weight, device_output, + + if(in_layout == "NCHW") + gpu_naive_conv_fwd_nchw_fp32(device_input, device_weight, device_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + gpu_naive_conv_fwd_nhwc_fp32(device_input, device_weight, device_output, + n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else + assert(0); HIP_CALL(hipDeviceSynchronize()); HIP_CALL(hipMemcpy(host_output, device_output, static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyDeviceToHost)); #else - conv_fwd_nchw(host_input, host_weight, host_output, n, wi, hi, c, + if(in_layout == "NCHW") + naive_conv_fwd_nchw(host_input, host_weight, host_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + naive_conv_fwd_nhwc(host_input, host_weight, host_output, n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else + assert(0); #endif device_output_to_host = (float *)malloc(static_cast(n) * k * ho * wo * sizeof(float)); } @@ -495,9 +508,9 @@ int main(int argc, char **argv) { printf("[fwd:%2d] %s, ", i, conv_fwd_driver.get_kernel_name(tunable).c_str()); fflush(stdout); - //if (need_verify) - // HIP_CALL(hipMemset(device_output, 0, - // n * c * ho * wo * sizeof(float))); + if (need_verify) + HIP_CALL(hipMemset(device_output, 0, n * c * ho * wo * sizeof(float))); + result_t result = conv_fwd_driver.run(&conv_args, tunable, module, device_input, device_weight, device_output, warmup, repeat); @@ -561,18 +574,33 @@ int main(int argc, char **argv) { static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice)); HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); - gpu_naive_conv_bwd_nchw_fp32(device_input, device_weight, device_output, + if(in_layout == "NCHW") + gpu_naive_conv_bwd_nchw_fp32(device_input, device_weight, device_output, + n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + gpu_naive_conv_bwd_nhwc_fp32(device_input, device_weight, device_output, n, wi, hi, 
c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else + assert(0); HIP_CALL(hipDeviceSynchronize()); HIP_CALL(hipMemcpy(host_input, device_input, static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyDeviceToHost)); #else - conv_bwd_nchw(host_input, host_weight, host_output, n, + if(in_layout == "NCHW") + naive_conv_bwd_nchw(host_input, host_weight, host_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + naive_conv_bwd_nhwc(host_input, host_weight, host_output, n, + wi, hi, c, k, x, y, pad_w, + pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else + assert(0); #endif device_input_to_host = (float *)malloc(static_cast(n) * c * hi * wi * sizeof(float)); // printf("len:%d\n", n * c * hi * wi * sizeof(float) ); @@ -657,18 +685,33 @@ int main(int argc, char **argv) { static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); HIP_CALL(hipMemcpy(device_output, host_output, static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice)); - gpu_naive_conv_wrw_nchw_fp32(device_input, device_weight, device_output, + if(in_layout == "NCHW") + gpu_naive_conv_wrw_nchw_fp32(device_input, device_weight, device_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + gpu_naive_conv_wrw_nhwc_fp32(device_input, device_weight, device_output, + n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + else + assert(0); HIP_CALL(hipDeviceSynchronize()); HIP_CALL(hipMemcpy(host_weight, device_weight, static_cast(ngroups) * (k / ngroups) * (c / ngroups) * y * x * sizeof(float), hipMemcpyDeviceToHost)); #else - conv_wrw_nchw(host_input, host_weight, host_output, n, + if(in_layout == "NCHW") + naive_conv_wrw_nchw(host_input, host_weight, host_output, n, wi, hi, c, k, x, y, pad_w, pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else if(in_layout == "NHWC") + naive_conv_wrw_nhwc(host_input, host_weight, host_output, n, + wi, hi, c, k, x, y, pad_w, + pad_h, stride_w, stride_h, dilation_w, dilation_h, ngroups); + else + assert(0); #endif device_weight_to_host = (float *)malloc(static_cast(k) * c * y * x * sizeof(float)); // printf("len:%d\n", k * c * y * x * sizeof(float)); diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 32bd90d2..fb5106af 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -148,11 +148,21 @@ class igemm_fwd_gtc_t { int nxb = tunable->nxb; int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; - int gemm_n = n * b; - + int gemm_m = 0; + int gemm_n = 0; + + if(tunable->tensor_layout == "nchw"){ + gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; + gemm_n = n * b; + }else if (tunable->tensor_layout == "nhwc"){ + gemm_m = n * b; + // gemm_n = ((k/group + gemm_n_per_block -1)/gemm_n_per_block) * gemm_n_per_block; + gemm_n = k / group; + }else{ + assert(false); + } size_t grid_size = static_cast(group) * utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * - utility_integer_divide_ceil(gemm_n, gemm_n_per_block); + utility_integer_divide_ceil(gemm_n, gemm_n_per_block); assert(grid_size <= 0xffffffffUL); return grid_size; } @@ -188,57 +198,113 @@ class igemm_fwd_gtc_t { int nxb = tunable->nxb; int b = nxe == 0 ? 
(ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; - int gemm_n = n * b; - int gemm_k = (c / group) * y * x; - bool unit_conv = (x==1)&&(y==1)&&(stride_h==1)&&(stride_w==1)&&(dilation_h==1)&&(dilation_w==1)&&(pad_h==0)&&(pad_w==0); - // support pad to modulo, hence only check when nxe is 0 - if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) - { - return false; - } - - if(gemm_n_per_block % tunable->nxb != 0){ - //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); - return false; - } - - if(n % (gemm_n_per_block / tunable->nxb) != 0){ - //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); - return false; - } - - if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ - return false; - } - - if((nxe == 0) && !unit_conv){ - return false; - } - - // input vector load limitation, n1b - if(tunable->tensor_b_thread_lengths[3] > 1 && ( - !unit_conv || - unit_conv && (hi * wi) % tunable->tensor_b_thread_lengths[3] != 0)) { - return false; - } - - // weight vector load limitation, c1e - if(tunable->tensor_a_thread_lengths[1] > 1 && - gemm_k % tunable->tensor_a_thread_lengths[1] != 0){ - return false; - } - - // if tb_c1e > 1, only 1x1 case is runable, it can not check gemm_k_padding either. - if(tunable->tensor_b_thread_lengths[1] > 1 && (( x !=1 || y != 1)||(gemm_k % gemm_k_per_block != 0))){ - return false; - } - - // if t_c0 > 1, need to check gemmk per block - if(tunable->tensor_b_thread_lengths[0] > 1 && (gemm_k % gemm_k_per_block != 0)){ - return false; + if(tunable->tensor_layout == "nchw"){ + int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; + int gemm_n = n * b; + int gemm_k = (c / group) * y * x; + + // support pad to modulo, hence only check when nxe is 0 + if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) + { + return false; + } + + if(gemm_n_per_block % tunable->nxb != 0){ + //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + return false; + } + + if(n % (gemm_n_per_block / tunable->nxb) != 0){ + //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + return false; + } + + if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ + return false; + } + + if((nxe == 0) && !unit_conv){ + return false; + } + + // input vector load limitation, n1b + if(tunable->tensor_b_thread_lengths[3] > 1 && ( + !unit_conv || + unit_conv && (hi * wi) % tunable->tensor_b_thread_lengths[3] != 0)) { + return false; + } + + // weight vector load limitation, c1e + if(tunable->tensor_a_thread_lengths[1] > 1 && + gemm_k % tunable->tensor_a_thread_lengths[1] != 0){ + return false; + } + + // if tb_c1e > 1, only 1x1 case is runable, it can not check gemm_k_padding either. 
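
For reference, the nhwc path below maps forward convolution to GEMM the other way around from the nchw assignment above: gemm_m covers n * ho * wo and gemm_n covers k / group. A minimal Python sketch of that mapping and the resulting grid size, with illustrative function and parameter names (not part of the driver itself):

def nhwc_fwd_gemm_dims(n, c, k, y, x, ho, wo, group, nxe, nxb):
    # b is ho*wo, padded up to a multiple of nxb only when nxe != 0
    b = ho * wo if nxe == 0 else ((ho * wo + nxb - 1) // nxb) * nxb
    gemm_m = n * b                   # every output pixel contributes one gemm_m row
    gemm_n = k // group              # output channels per group
    gemm_k = (c // group) * y * x    # reduction over input channels and filter taps
    return gemm_m, gemm_n, gemm_k

def nhwc_grid_size(gemm_m, gemm_n, gemm_m_per_block, gemm_n_per_block, group):
    # same rounding as utility_integer_divide_ceil in the real driver
    return group * ((gemm_m + gemm_m_per_block - 1) // gemm_m_per_block) \
                 * ((gemm_n + gemm_n_per_block - 1) // gemm_n_per_block)
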
+ if(tunable->tensor_b_thread_lengths[1] > 1 && (( x !=1 || y != 1)||(gemm_k % gemm_k_per_block != 0))){ + return false; + } + + // if t_c0 > 1, need to check gemmk per block + if(tunable->tensor_b_thread_lengths[0] > 1 && (gemm_k % gemm_k_per_block != 0)){ + return false; + } + }else if(tunable->tensor_layout == "nhwc"){ + int gemm_m = n * b; + // int gemm_n = ((k/group + gemm_n_per_block -1)/gemm_n_per_block) * gemm_n_per_block; + int gemm_n = k / group; + int gemm_k = (c / group) * y * x; + + // support pad to modulo, hence only check when nxe is 0 + if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) + { + return false; + } + + if(gemm_m_per_block % tunable->nxb != 0){ + //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + return false; + } + + if(n % (gemm_m_per_block / tunable->nxb) != 0){ + //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + return false; + } + + if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ + return false; + } + + if((nxe == 0) && !unit_conv){ + return false; + } + + // input vector load limitation, n1b + if(tunable->tensor_a_thread_lengths[3] > 1 && ( + !unit_conv || + unit_conv && (hi * wi) % tunable->tensor_a_thread_lengths[3] != 0)) { + return false; + } + + // // weight vector load limitation, c1e + // if(tunable->tensor_a_thread_lengths[1] > 1 && + // gemm_k % tunable->tensor_a_thread_lengths[1] != 0){ + // return false; + // } + + // // if tb_c1e > 1, only 1x1 case is runable, it can not check gemm_k_padding either. + // if(tunable->tensor_b_thread_lengths[1] > 1 && (( x !=1 || y != 1)||(gemm_k % gemm_k_per_block != 0))){ + // return false; + // } + + // // if t_c0 > 1, need to check gemmk per block + // if(tunable->tensor_b_thread_lengths[0] > 1 && (gemm_k % gemm_k_per_block != 0)){ + // return false; + // } + }else{ + assert(0); } return true; } @@ -303,10 +369,10 @@ class igemm_fwd_gtc_t { karg.x = x; karg.group = group; - int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; - int gemm_n = n * b; #if USE_MAGIC_DIV + int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; + int gemm_n = n * b; { // init magic division parameters uint32_t nb_n0 = tunable->tensor_b_cluster_lengths[2] * tunable->tensor_b_thread_lengths[2]; From 1b3df1346b20a2a0fd985b39675d982fb205487a Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 18 Jan 2021 18:31:36 +0800 Subject: [PATCH 04/40] fix crash --- driver/conv_driver.cpp | 10 +++--- igemm/algo/igemm_fwd_gtc_nhwc.py | 53 +++++++++++++++++--------------- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index f431008d..93244774 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -252,13 +252,15 @@ static inline bool valid_vector(const float *ref, const float *pred, size_t n, double s1 = 0.0; int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0); int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1); + int igemm_valid_float = env_get_int("VALID_FLOAT", 1); size_t pp_err = 0; for (size_t i = 0; i < n; ++i) { - if(!(valid_float(ref[i]) && valid_float(pred[i]))){ - printf(" invalid float at %zu, ref:%f, pred:%f\n", i, ref[i], pred[i]); - return false; - } + if(igemm_valid_float) + if(!(valid_float(ref[i]) && 
valid_float(pred[i]))){ + printf(" invalid float at %zu, ref:%f, pred:%f\n", i, ref[i], pred[i]); + return false; + } double ri = (double)ref[i]; double pi = (double)pred[i]; double d = ri - pi; diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index b420b2d2..a98723c4 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -319,13 +319,13 @@ def __call__(self): with self._deferred_context(): self._emit(f"; load input") if self.outer.tunable.nxe != 0: - self._emit(f".v_clear_nc {v.v_gld_b()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") + self._emit(f".v_clear_nc {v.v_gld_a()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") self._emit(f"v_cmp_eq_u32 vcc, 1, v[{v.v_in_flag()}]") self._emit(f"s_and_saveexec_b64 s[{s.s_tmp(4)}:{s.s_tmp(5)}], vcc") if self.outer.tunable.precache_soffset: - self._emit(m_in_2d_global_load(v.v_gld_b(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_in_offset())) + self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_in_offset())) else: - self._emit(m_in_2d_global_load(v.v_gld_b(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_tmp())) + self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_tmp())) if self.outer.tunable.nxe != 0: self._emit(f"s_or_b64 exec, exec, s[{s.s_tmp(4)}:{s.s_tmp(5)}]") return self._get_deferred() @@ -360,9 +360,9 @@ def __call__(self): self._emit(f"; load weight") # self._emit(f".v_clear_nc {v.v_gld_a()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") if self.outer.tunable.precache_soffset: - self._emit(m_wei_2d_global_load(v.v_gld_a(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset())) + self._emit(m_wei_2d_global_load(v.v_gld_b(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset())) else: - self._emit(m_wei_2d_global_load(v.v_gld_a(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_tmp())) + self._emit(m_wei_2d_global_load(v.v_gld_b(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_tmp())) return self._get_deferred() class shared_store_in_t(mc_base_t): @@ -484,8 +484,6 @@ def __init__(self, mc, outer): if outer.tunable.nxe != 0: self.s_wei_stride_y = sym_t('s_wei_stride_y' , sseq(1)) self.s_stride_c = sym_t('s_stride_c' , sseq(1)) - self.s_in_stride_c = sym_t("s_in_stride_c" , self.s_stride_c.value) - self.s_wei_stride_c = sym_t("s_wei_stride_c" , self.s_stride_c.value) # stride for out self.s_out_stride_wo = sym_t('s_out_stride_wo' , sseq(1)) @@ -930,12 +928,12 @@ def get_symbol_global_load_s_stride_d0_d1(self): in_stride_gprs = [s.s_in_stride_n0 if ta_n0 != 1 else s_dummy, s.s_in_stride_wi, s_dummy, - s.s_in_stride_c] + s.s_stride_c] # [tb_k, ta_e, ta_c] wei_stride_gprs = [s.s_wei_stride_k, s_dummy, - s.s_wei_stride_c] + s.s_stride_c] if self.in_thread_copy_ndim == 2: s_in_stride_d0 = in_stride_gprs[in_thread_copy_index[0]] @@ -1145,7 +1143,7 @@ def emit_kernel_prologue(self): self._emit(f"s_load_dword s[{s.s_magic_6()}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_6()}") self._emit(f"s_load_dwordx2 s[{s.s_shift_pack_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") - self._emit(f"; in(e, c, n0, n1b) thread_lengths: {ta_e}x{ta_c}x{ta_n0}x{ta_n1b}, cluster_length: {ca_e}x{ca_c}x{ca_n0}x{ca_n1b}") + self._emit(f"; in(e, c, n0, n1b) 
thread_lengths: {ta_e}x{ta_c}x{ta_n0}x{ta_n1b}, cluster_length: {ca_e}x{ca_c}x{ca_n0}x{ca_n1b}, unmerge_sub_n:{unmerge_sub_n}, unmerge_sub_n1:{unmerge_sub_n1}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") self._emit(tc_index_dispatcher(v.v_gtc_ta_ic(), v.v_tmp(), ca_c, ta_c)) if ca_n0 != 1: @@ -1197,8 +1195,8 @@ def emit_kernel_prologue(self): else: # input - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_c()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_tmp()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_in_stride_hi()}]") if ta_n0 != 1: self._emit(f"s_lshl_b32 s[{s.s_in_stride_n0()}], s[{s.s_in_stride_n()}], {utility_log2(unmerge_sub_n1)}") @@ -1272,9 +1270,11 @@ def emit_kernel_prologue(self): self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ik()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_n_per_block)}") + # to compute ho*wo*sub_n1 // na_n1b if unmerge_sub_n1 == 1: self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_n1b)} ; total number of n1b") else: + assert na_n1b >= unmerge_sub_n1 if unmerge_sub_n1 == na_n1b: self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of n1b") else: @@ -1282,9 +1282,9 @@ def emit_kernel_prologue(self): if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(4), s.s_magic_1(), s.s_tmp(3), '0', s.s_tmp())) + self._emit(m_mdiv_u32_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), s.s_magic_1(), s.s_tmp(3), '0', s.s_tmp())) else: - self._emit(m_int_div_rem_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(4), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) if na_n1b != 1: self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in1b()}], s[{s.s_block_gtc_in1b()}], {igemm_log2(na_n1b)}") @@ -1349,6 +1349,7 @@ def emit_kernel_prologue(self): # self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_gtc_ta_in1()}]") # s_in_stride_wi need shift before! 
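
The try_shift_stride call added below pre-scales s_in_stride_wi from element units to byte units once (a left shift by log2(data_byte)), so the later v_mul_lo_u32 / v_add_lshl_u32 sequence can accumulate a byte offset directly. A hedged Python sketch of the nhwc input address being assembled; the names mirror the sgpr/vgpr symbols and data_byte = 4 assumes fp32:

def nhwc_in_offset(i_n, i_hi, i_wi, i_c, hi, wi, c, group, data_byte=4):
    # strides as computed in the prologue: the wi step spans the merged channel dim
    in_stride_wi = c * group
    in_stride_hi = wi * in_stride_wi
    in_stride_n  = hi * in_stride_hi
    # byte offset of input element (i_n, i_hi, i_wi, i_c)
    return (i_n * in_stride_n + i_hi * in_stride_hi
            + i_wi * in_stride_wi + i_c) * data_byte
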
+ self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) if self.tunable.nxe != 0: self._emit(f"v_add_lshl_u32 v[{v.v_in_os_base()}], v[{v.v_gtc_ta_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi()}]") @@ -1378,9 +1379,10 @@ def emit_kernel_prologue(self): # load in self._emit(self.global_load_in()) self._emit_empty_line() + self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") # config weight range - self._emit("; config for weight range") + #self._emit("; config for weight range") #self._emit(f"s_mul_i32 s[{s.s_p_wei(2)}], s[{s.s_wei_stride_k() if self.tunable.nxe != 0 else s.s_c()}], s[{s.s_k()}]") #self._emit(f"s_lshl_b32 s[{s.s_p_wei(2)}], s[{s.s_p_wei(2)}], {igemm_log2(data_byte)}") self._emit(f"s_mov_b32 s[{s.s_p_wei(3)}], 0x27000") @@ -1512,6 +1514,7 @@ def emit_kernel_prologue(self): # else: # assert False, "un implemented, should rarely be used" + # TODO: extend tensor size, here vgpr only have 32bit self._emit(f"; compute from n1b") self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_out_in1b()}]") if self.tunable.nxe != 0: @@ -1605,11 +1608,11 @@ def emit_kernel_prologue(self): # self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c1e()}], {na_c0 * na_c1e}") # if IGEMM_GTC_FEAT_MAGIC_DIVISION: # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") - # self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_magic_2(), s.s_tmp(3), s.s_wei_stride_c(), s.s_tmp())) + # self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_magic_2(), s.s_tmp(3), s.s_stride_c(), s.s_tmp())) # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080018 ; offset:24, width:8") # self._emit(m_mdiv_u32_ss(s.s_move_slice_k_x(), s.s_move_slice_k_y(), s.s_tmp(4), s.s_magic_3(), s.s_tmp(3), s.s_x(), s.s_tmp())) # else: - # self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_wei_stride_c(), v.v_tmp(4), v.v_tmp(), s.s_tmp())) + # self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_stride_c(), v.v_tmp(4), v.v_tmp(), s.s_tmp())) # self._emit(m_int_div_rem_ss(s.s_move_slice_k_x(), s.s_move_slice_k_y(), s.s_tmp(4), s.s_x(), v.v_tmp(4), v.v_tmp(), s.s_tmp())) # else: # #assert na_c1e == nb_c1e @@ -1621,11 +1624,11 @@ def emit_kernel_prologue(self): # if self.tunable.nxe != 0: # # assert s.s_out_stride_k.label not in self.dict_shifted_stride and s.s_wei_stride_k.label not in self.dict_shifted_stride - # if s.s_in_stride_c.label not in self.dict_shifted_stride: - # self._emit(m_move_slice_window_tb.init_stride_c(s.s_in_stride_c(), s.s_in_stride_c_c1(), + # if s.s_stride_c.label not in self.dict_shifted_stride: + # self._emit(m_move_slice_window_tb.init_stride_c(s.s_stride_c(), s.s_in_stride_c_c1(), # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1())) # else: - # self._emit(f"s_lshr_b32 s[{s.s_tmp(3)}], s[{s.s_in_stride_c()}], {utility_log2(data_byte)}") + # self._emit(f"s_lshr_b32 s[{s.s_tmp(3)}], s[{s.s_stride_c()}], {utility_log2(data_byte)}") # self._emit(m_move_slice_window_tb.init_stride_c(s.s_tmp(3), s.s_in_stride_c_c1(), # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1())) # else: @@ -1639,7 +1642,7 @@ def emit_kernel_prologue(self): # if not self.is_1d_move_slice_k(): # self._emit(f"s_mov_b32 s[{s.s_gemm_k_num_c1()}], {unmerge_sub_tb_c1}") #if self.tunable.nxe != 0: - # self._emit(f"s_mul_i32 
s[{s.s_knum()}], s[{s.s_wei_stride_c()}], s[{s.s_c()}]") + # self._emit(f"s_mul_i32 s[{s.s_knum()}], s[{s.s_stride_c()}], s[{s.s_c()}]") #else: # self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_c()}]") self._emit_empty_line() @@ -1650,11 +1653,11 @@ def emit_kernel_prologue(self): #self._emit(self.try_shift_stride(s.s_wei_stride_k_k0_k1_diff, igemm_log2(data_byte))) if self.tunable.nxe != 0: - self._emit(self.try_shift_stride(s.s_in_stride_c, igemm_log2(data_byte))) + # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) self._emit(self.try_shift_stride(s.s_wei_stride_k, igemm_log2(data_byte))) # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) else: - self._emit(self.try_shift_stride(s.s_in_stride_c, igemm_log2(data_byte))) + # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) self._emit(self.try_shift_stride(s.s_c, igemm_log2(data_byte))) # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) @@ -1794,10 +1797,10 @@ def emit_kernel_epilogue(self): if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # if self.tunable.nxe != 0: # self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None, - # s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_in_stride_c(), s.s_tmp(), v.v_in_flag())) + # s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_stride_c(), s.s_tmp(), v.v_in_flag())) # else: # self._emit(self.coalescing_store(v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_in(), v.v_in_os(), None, - # s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_in_stride_c(), s.s_tmp())) + # s.s_in_stride_c0() if self.tunable.gemm_m_unmerge_cluster == 1 else None, s.s_stride_c(), s.s_tmp())) assert False else: a = self.agpr From b72e0a1e98354123ebe1dc616e7fd6fd276f643f Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 19 Jan 2021 08:48:26 +0800 Subject: [PATCH 05/40] remove several code --- igemm/algo/igemm_fwd_gtc_nhwc.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index a98723c4..12d8a5ec 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -1438,19 +1438,12 @@ def emit_kernel_prologue(self): self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in1b()}]") else: self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_n1b)}, v[{v.v_gtc_ta_in1b()}]") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(na_n0*na_n1b)}, v[{v.v_tmp()}]") - #if cb_c0 != 1: - # self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_tb_ic0()}], {igemm_log2(nb_c1e*nb_n0*nb_n1b)}, v[{v.v_tmp()}]") - self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") - # self._emit(f"v_add_u32 v[{v.v_sst_a_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_a_os()}]") self._emit_empty_line() self._emit(f"; LDS store, wei: e,c,k: {ta_e}x{ta_c}x{tb_k}, {ca_e}x{ca_c}x{cb_k}") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(nb_k)}, v[{v.v_gtc_tb_ik()}]") - self._emit(f"v_lshlrev_b32 v[{v.v_sst_b_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") self._emit(f"v_add_u32 v[{v.v_sst_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_b_os()}]") self._emit_empty_line() @@ -1502,17 +1495,6 @@ def emit_kernel_prologue(self): assert False, "un implemented, should rarely be used" else: assert False - # if na_n0 != 1: 
- # self._emit(f"v_and_b32 v[{v.v_out_in0()}], {na_n0 - 1}, v[{v.v_co_sub_m_index()}] ; => N0") - # if na_n1b != 1: - # self._emit(f"v_lshrrev_b32 v[{v.v_out_in1b()}], {igemm_log2(na_n0)}, v[{v.v_co_sub_m_index()}] ; => N1B") - # else: - # assert False, "un implemented, should rarely be used" - # else: - # if na_n1b != 1: - # self._emit(f"v_mov_b32 v[{v.v_out_in1b()}], v[{v.v_co_sub_m_index()}] ; => N1B") - # else: - # assert False, "un implemented, should rarely be used" # TODO: extend tensor size, here vgpr only have 32bit self._emit(f"; compute from n1b") From cacc8efbe82cff7696f509e45bd473da4c64eaed Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 19 Jan 2021 23:09:56 +0800 Subject: [PATCH 06/40] k_pack for smem --- config/igemm_fwd_gtc_gfx908_nhwc.config | 16 +- igemm/algo/igemm_fwd_gtc_nhwc.py | 339 +++++++++--------------- igemm/algo/shared_memory.py | 83 ++++++ 3 files changed, 216 insertions(+), 222 deletions(-) diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index 2c664ca3..25b44581 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -14,10 +14,10 @@ wave_repeat_m = 2 wave_tile_n = 32 wave_step_n = 1 wave_repeat_n = 2 -tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxN0xN1B -tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxN0xN1B -tensor_b_thread_lengths = [1, 4, 2] # ExCxK -tensor_b_cluster_lengths = [1, 4, 64] # ExCxK +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' @@ -35,10 +35,10 @@ wave_repeat_m = 2 wave_tile_n = 32 wave_step_n = 1 wave_repeat_n = 2 -tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxN0xN1B -tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxN0xN1B -tensor_b_thread_lengths = [1, 4, 2] # ExCxK -tensor_b_cluster_lengths = [1, 4, 64] # ExCxK +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index 12d8a5ec..f6d7144c 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -35,13 +35,6 @@ from .coalescing_store import * from .mfma_main_loop import * - -IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B = 0 -IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0 = 1 -# IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_N_E_K = 4 -# IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_N_K_E = 5 - - def _find_non_1_index_in_list(list_object): result_list = list() for idx, item in enumerate(list_object): @@ -53,8 +46,8 @@ def _find_non_1_index_in_list(list_object): class igemm_fwd_gtc_nhwc_t(mc_base_t): ''' tensor a (input) tensor b (wei) - thread_lengths : ta_e, ta_c, ta_n0, ta_n1b, tb_e, tb_c, tb_k - cluster_lengths : ca_e, ca_c, ca_n0, ca_n1b, cb_e, cb_c, cb_k + thread_lengths : ta_e, ta_c, ta_nb0, ta_nb1, tb_e, tb_c, tb_k0, tb_k1 + cluster_lengths : ca_e, ca_c, ca_nb0, ca_nb1, cb_e, cb_c, cb_k0, cb_k1 for a/b tensor, always load gemm_k dimension first. 
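
The four-element thread/cluster lengths in the docstring above compose the per-block tile: each dimension's per-block extent is thread_length * cluster_length, so gemm_m_per_block = na_nb0 * na_nb1, gemm_n_per_block = nb_k0 * nb_k1, gemm_k_per_block = na_e * na_c, and the block size is the product of one tensor's cluster lengths (both tensors must agree). A small self-check against the 256x128x16 entries in the config file above; the helper name is illustrative:

def tile_from_lengths(t_a, c_a, t_b, c_b):
    # order follows the docstring: (e, c, nb0, nb1) for tensor a, (e, c, k0, k1) for tensor b
    ta_e, ta_c, ta_nb0, ta_nb1 = t_a
    ca_e, ca_c, ca_nb0, ca_nb1 = c_a
    tb_e, tb_c, tb_k0, tb_k1 = t_b
    cb_e, cb_c, cb_k0, cb_k1 = c_b
    gemm_k_per_block = (ta_e * ca_e) * (ta_c * ca_c)
    gemm_m_per_block = (ta_nb0 * ca_nb0) * (ta_nb1 * ca_nb1)
    gemm_n_per_block = (tb_k0 * cb_k0) * (tb_k1 * cb_k1)
    block_size = ca_e * ca_c * ca_nb0 * ca_nb1   # must equal cb_e * cb_c * cb_k0 * cb_k1
    return gemm_m_per_block, gemm_n_per_block, gemm_k_per_block, block_size

assert tile_from_lengths([1, 4, 4, 1], [1, 4, 1, 64],
                         [1, 4, 2, 1], [1, 4, 1, 64]) == (256, 128, 16, 256)
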
@@ -121,12 +114,12 @@ def flatten(x): ctrl_coalescing_store_xdlops.vector_write_out = 1 # TODO: some cases this can be set to other value ctrl_coalescing_store_xdlops.block_size = self.tunable.block_size - gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() - na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() - ctrl_coalescing_store_xdlops.gemm_m_m0_m1 = [na_n0, na_n1b] - if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0: - # we may consider not suppor this mode - ctrl_coalescing_store_xdlops.gemm_m_order = IGEMM_COALESCING_GEMM_M_ORDER_M1_M0 + # gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() + ctrl_coalescing_store_xdlops.gemm_m_m0_m1 = [na_nb0, na_nb1] + #if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0: + # # we may consider not suppor this mode + # ctrl_coalescing_store_xdlops.gemm_m_order = IGEMM_COALESCING_GEMM_M_ORDER_M1_M0 ctrl_coalescing_store_xdlops.adjust_optimal_coalescing_groups() # in m1_m0 order, must adjust self.coalescing_store = igemm_coalescing_store_xdlops_t(mc, ctrl_coalescing_store_xdlops) @@ -150,7 +143,9 @@ def try_shift_stride(self, gpr, shifter): self.dict_shifted_stride[gpr.label] = gpr self._emit(f"s_lshl_b32 s[{gpr()}], s[{gpr()}], {shifter}") return self._get_deferred() - + + # will not support order, since nhwc fix order is enough + ''' def get_lds_gemm_m_gemm_n_order(self): def need_reverse_order(x0, x1): if x0 != 1 and x1 == 1: @@ -159,17 +154,18 @@ def need_reverse_order(x0, x1): return True return False - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() gemm_n_order = -1 # gemm_n order is not supported gemm_m_order = IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B if self.tunable.allow_lds_reorder: - if need_reverse_order(ta_n0, ta_n1b): + if need_reverse_order(ta_nb0, ta_nb1): gemm_m_order = IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0 assert False, "maybe not correct" return gemm_m_order, gemm_n_order + ''' class macro_set_flag_hw(macro_base_t): def __init__(self, mc, inline = False): @@ -334,7 +330,7 @@ def __call__(self): # ''' # this now only meaning for input tensor # ''' - # na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + # na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() # if self.tunable.nxe != 0: # return False # if not nxe 0, it is possible that we can do move slice, but that will lead to extra index calculation # if nb_c1e != 1 and nb_c0 == 1: @@ -446,7 +442,7 @@ def emit(self): class kernel_sgpr_t(mc_base_t): def __init__(self, mc, outer): mc_base_t.__init__(self, mc) - ta_n0, ta_n1b, ta_e, ta_c, tb_k = outer.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = outer.get_thread_lengths() sseq = gpr_sequencer_t() self.outer = outer self.s_ka = sym_t('s_ka' , sseq(2)) @@ -476,10 +472,12 @@ def __init__(self, mc, outer): self.s_in_stride_hi = sym_t('s_in_stride_hi' , sseq(1)) self.s_in_stride_wi = sym_t('s_in_stride_wi' , sseq(1)) self.s_in_stride_n = sym_t('s_in_stride_n' , sseq(1)) - if ta_n0 != 1: + if ta_nb0 != 1: self.s_in_stride_n0 = sym_t('s_in_stride_n0' , sseq(1)) # stride for wei + if tb_k0 != 1: + self.s_wei_stride_k0 = sym_t('s_wei_stride_k0' , sseq(1)) self.s_wei_stride_k = sym_t('s_wei_stride_k' , sseq(1)) if outer.tunable.nxe != 0: self.s_wei_stride_y = sym_t('s_wei_stride_y' , sseq(1)) @@ -488,7 +486,7 @@ def __init__(self, mc, outer): # stride for out 
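
Because nhwc keeps the channel dimension innermost, the output and weight stride registers declared just below end up holding (in elements, before any byte shift): out_stride_wo = k * group, out_stride_n = ho * wo * k * group, and wei_stride_k = c * y * x when nxe != 0 (plain c in the unit 1x1 case), matching the prologue emitted later. A hedged value-only sketch, where c is assumed to be the per-group channel count the kernel sees:

def nhwc_out_wei_strides(k, c, y, x, ho, wo, group, nxe):
    out_stride_wo = k * group                       # step between adjacent wo positions
    out_stride_n = ho * wo * out_stride_wo          # step between adjacent batch images
    wei_stride_k = c * y * x if nxe != 0 else c     # step between adjacent output channels
    return out_stride_wo, out_stride_n, wei_stride_k
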
self.s_out_stride_wo = sym_t('s_out_stride_wo' , sseq(1)) self.s_out_stride_n = sym_t('s_out_stride_n' , sseq(1)) - if ta_n0 != 1: + if ta_nb0 != 1: self.s_out_stride_n0 = sym_t('s_out_stride_n0' , sseq(1)) self.s_in_stride_c_c1 = sym_t("s_in_stride_c_c1" , sseq(1)) @@ -566,8 +564,8 @@ class kernel_vgpr_t(mc_base_t): def __init__(self, mc, outer): mc_base_t.__init__(self, mc) self.outer = outer - ta_n0, ta_n1b, ta_e, ta_c, tb_k = outer.get_thread_lengths() - ca_n0, ca_n1b, ca_e, ca_c, cb_k = outer.get_cluster_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = outer.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = outer.get_cluster_lengths() is_vgpr_acc_c = outer.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS vseq = gpr_sequencer_t() @@ -599,11 +597,13 @@ def __init__(self, mc, outer): self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) self.v_gtc_ta_ic = sym_t("v_gtc_ta_ic" ,vseq(1)) - if ca_n0 != 1: + if ca_nb0 != 1: self.v_gtc_ta_in0 = sym_t("v_gtc_ta_in0" ,vseq(1)) self.v_gtc_ta_in1b = sym_t("v_gtc_ta_in1b" ,vseq(1)) self.v_gtc_ta_in1 = sym_t("v_gtc_ta_in1" ,vseq(1)) + if tb_k0 != 1: + self.v_gtc_tb_ik0 = sym_t("v_gtc_tb_ik0" ,vseq(1)) self.v_gtc_tb_ik = sym_t("v_gtc_tb_ik" ,vseq(1)) self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) @@ -675,10 +675,10 @@ def get_thread_lengths(self): t_ta = self.tunable.tensor_a_thread_lengths t_tb = self.tunable.tensor_b_thread_lengths - assert len(t_ta) == 4 and len(t_tb) == 3 + assert len(t_ta) == 4 and len(t_tb) == 4 - ta_e, ta_c, ta_n0, ta_n1b = t_ta[0], t_ta[1], t_ta[2], t_ta[3] - tb_e, tb_c, tb_k = t_tb[0], t_tb[1], t_tb[2] + ta_e, ta_c, ta_nb0, ta_nb1 = t_ta[0], t_ta[1], t_ta[2], t_ta[3] + tb_e, tb_c, tb_k0, tb_k1 = t_tb[0], t_tb[1], t_tb[2], t_tb[3] assert ta_e == tb_e and ta_c == tb_c @@ -691,36 +691,36 @@ def get_thread_lengths(self): else: pass - return ta_n0, ta_n1b, ta_e, ta_c, tb_k # M, N, K + return ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 # M, K, N def get_cluster_lengths(self): c_ta = self.tunable.tensor_a_cluster_lengths c_tb = self.tunable.tensor_b_cluster_lengths - assert len(c_ta) == 4 and len(c_tb) == 3 + assert len(c_ta) == 4 and len(c_tb) == 4 - ca_e, ca_c, ca_n0, ca_n1b = c_ta[0], c_ta[1], c_ta[2], c_ta[3] - cb_e, cb_c, cb_k = c_tb[0], c_tb[1], c_tb[2] + ca_e, ca_c, ca_nb0, ca_nb1 = c_ta[0], c_ta[1], c_ta[2], c_ta[3] + cb_e, cb_c, cb_k0, cb_k1 = c_tb[0], c_tb[1], c_tb[2], c_tb[3] assert ca_e == cb_e and ca_c == cb_c - assert ca_e == 1 and ca_n0 == 1 + assert ca_e == 1 and ca_nb0 == 1 and cb_k0 == 1 - return ca_n0, ca_n1b, ca_e, ca_c, cb_k # M, N, K + return ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 # M, K, N def get_dims_lengths(self): - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() - ca_n0, ca_n1b, ca_e, ca_c, cb_k = self.get_cluster_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = self.get_cluster_lengths() - na_n0, na_n1b, na_e, na_c = ta_n0 * ca_n0, ta_n1b * ca_n1b, ta_e * ca_e, ta_c * ca_c - nb_k = tb_k * cb_k + na_nb0, na_nb1, na_e, na_c = ta_nb0 * ca_nb0, ta_nb1 * ca_nb1, ta_e * ca_e, ta_c * ca_c + nb_k0, nb_k1 = tb_k0 * cb_k0, tb_k1 * cb_k1 - return na_n0, na_n1b, na_e, na_c, nb_k # M, N, K + return na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 # M, K, N def get_thread_copy_dims(self): - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() - in_thread_copy_dims = [ta_n0, ta_n1b, ta_e, ta_c] - wei_thread_copy_dims = [tb_k, ta_e, ta_c] # always reordered! 
+ ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + in_thread_copy_dims = [ta_nb0, ta_nb1, ta_e, ta_c] + wei_thread_copy_dims = [tb_k0, tb_k1, ta_e, ta_c] # always reordered! return in_thread_copy_dims, wei_thread_copy_dims def get_thread_copy_index(self): @@ -739,8 +739,8 @@ def get_macro_global_load(self): NOTICE: input/wei always load gemm_k (e*c) first. indeed always load c, and do vector load if possible ''' inline = True if self.tunable.fma_interleave else False - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() - na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() @@ -778,121 +778,32 @@ def get_macro_global_load(self): def get_macro_shared_store(self): - in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() - in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() - na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + #in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() + #in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() data_byte = amdgpu_precision_data_byte(self.tunable.precision) - gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + k_pack = ta_c # always use this as k_pack - ## give the LDS strides of wei dimensions [ta_k0, ta_k1, ta_c0, ta_c1e] - #if gemm_m_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_M_K0_K1: - # wei_stride_list = [na_k1, 1, na_c1e*na_k0*na_k1, na_k0*na_k1] - #else: - # wei_stride_list = [1, na_k0, na_c1e*na_k0*na_k1, na_k0*na_k1] - - ## give the LDS strides of in dimensions [tb_c0, tb_c1e, tb_n0, tb_n1b] - #if gemm_n_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_N_N0_N1B: - # in_stride_list = [nb_c1e*nb_n0*nb_n1b, nb_n0*nb_n1b, nb_n1b, 1] - #else: - # in_stride_list = [nb_c1e*nb_n0*nb_n1b, nb_n0*nb_n1b, 1, nb_n0] - - # [ta_n0, ta_n1b, ta_e, ta_c] - if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B: - in_stride_list = [na_n1b, 1, na_c*na_n0*na_n1b, na_n0*na_n1b] - else: - in_stride_list = [1, na_n0, na_c*na_n0*na_n1b, na_n0*na_n1b] - - # [tb_k, ta_e, ta_c] - wei_stride_list = [1, nb_k*na_c, nb_k] - - in_sst_ctrl = ctrl_2d_shared_store_t() - in_sst_ctrl.src_order = 1 - in_sst_ctrl.v_tmp = self.vgpr.v_tmp + # input is gemm_k * gemm_m * k_pack + in_sst_ctrl = ctrl_3d_shared_store_t() + in_sst_ctrl.length_d0 = ta_nb0 + in_sst_ctrl.length_d1 = ta_nb1 + in_sst_ctrl.length_dp = k_pack + in_sst_ctrl.stride_d0 = na_nb1 * k_pack * data_byte + in_sst_ctrl.stride_d1 = k_pack * data_byte - wei_sst_ctrl = ctrl_2d_shared_store_t() - wei_sst_ctrl.src_order = 1 - wei_sst_ctrl.v_tmp = self.vgpr.v_tmp + # wei is gemm_k * gemm_n * k_pack + wei_sst_ctrl = ctrl_3d_shared_store_t() + wei_sst_ctrl.length_d0 = tb_k0 + wei_sst_ctrl.length_d1 = tb_k1 + wei_sst_ctrl.length_dp = k_pack + wei_sst_ctrl.stride_d0 = nb_k1 * k_pack * data_byte + wei_sst_ctrl.stride_d1 = k_pack * data_byte - # [ta_n0, ta_n1b, ta_e, ta_c] - if self.in_thread_copy_ndim == 2: - if in_thread_copy_index[0] in (0, 1) and in_thread_copy_index[1] in (2, 3): - in_sst_ctrl.length_d0 = 
in_thread_copy_dims[in_thread_copy_index[1]] - in_sst_ctrl.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] - in_sst_ctrl.stride_d0 = in_stride_list[in_thread_copy_index[1]] * data_byte - in_sst_ctrl.stride_d1 = in_stride_list[in_thread_copy_index[0]] * data_byte - else: - in_sst_ctrl.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] - in_sst_ctrl.length_d1 = in_thread_copy_dims[in_thread_copy_index[1]] - in_sst_ctrl.stride_d0 = in_stride_list[in_thread_copy_index[0]] * data_byte - in_sst_ctrl.stride_d1 = in_stride_list[in_thread_copy_index[1]] * data_byte - if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B: - in_sst_ctrl.vector_d1 = ta_n1b - else: - in_sst_ctrl.vector_d1 = in_thread_copy_dims[in_thread_copy_index[1]] - - elif self.in_thread_copy_ndim == 1: - in_sst_ctrl.length_d0 = 1 - in_sst_ctrl.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] - if (gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B and ta_n1b != 1) or \ - (gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0 and ta_n0 != 1): - in_sst_ctrl.vector_d1 = in_thread_copy_dims[in_thread_copy_index[0]] - else: - in_sst_ctrl.vector_d1 = 1 - in_sst_ctrl.stride_d0 = 1 - in_sst_ctrl.stride_d1 = in_stride_list[in_thread_copy_index[0]] * data_byte - if in_sst_ctrl.length_d1 == 8 and in_sst_ctrl.vector_d1 != 1: - # assert False - # TODO: this is indeed not optimal. may consider shuffle in the future. - in_sst_ctrl.length_d0 = 2 - in_sst_ctrl.length_d1 = 4 - in_sst_ctrl.vector_d1 = 4 - in_sst_ctrl.stride_d0 = 4 * data_byte - else: - assert False - - # [tb_k, ta_e, ta_c] - if self.wei_thread_copy_ndim == 2: - if wei_thread_copy_index[0] in (0,) and wei_thread_copy_index[1] in (1, 2): - # when store into LDS, reorder back. indeed we always wish this pattern, if ndim is 2 - wei_sst_ctrl.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[1]] - wei_sst_ctrl.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] - wei_sst_ctrl.stride_d0 = wei_stride_list[wei_thread_copy_index[1]] * data_byte - wei_sst_ctrl.stride_d1 = wei_stride_list[wei_thread_copy_index[0]] * data_byte - else: - wei_sst_ctrl.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[0]] - wei_sst_ctrl.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[1]] - wei_sst_ctrl.stride_d0 = wei_stride_list[wei_thread_copy_index[0]] * data_byte - wei_sst_ctrl.stride_d1 = wei_stride_list[wei_thread_copy_index[1]] * data_byte - wei_sst_ctrl.need_transpose = 0 - wei_sst_ctrl.vector_d1 = tb_k - - elif self.wei_thread_copy_ndim == 1: - wei_sst_ctrl.length_d0 = 1 - wei_sst_ctrl.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] - - if ta_c != 1: - wei_sst_ctrl.vector_d1 = utility_gcd(wei_thread_copy_dims[wei_thread_copy_index[0]], 4) - else: - wei_sst_ctrl.vector_d1 = 1 - - wei_sst_ctrl.stride_d0 = 1 - wei_sst_ctrl.stride_d1 = wei_stride_list[wei_thread_copy_index[0]] * data_byte - if wei_sst_ctrl.length_d1 == 8 and wei_sst_ctrl.vector_d1 != 1: - # assert False - # TODO: this is indeed not optimal. may consider shuffle in the future. 
- wei_sst_ctrl.length_d0 = 2 - wei_sst_ctrl.length_d1 = 4 - wei_sst_ctrl.vector_d1 = 4 - wei_sst_ctrl.stride_d0 = 4 * data_byte - else: - assert False - - # print(f"in_sst_ctrl.vector_d1:{in_sst_ctrl.vector_d1}, wei_sst_ctrl.vector_d1:{wei_sst_ctrl.vector_d1}") - # print(f"wei_sst_ctrl, {wei_sst_ctrl.serialize()}") inline = True if self.tunable.fma_interleave else False - return macro_igemm_2d_shared_store_t(self.mc, in_sst_ctrl, inline), macro_igemm_2d_shared_store_t(self.mc, wei_sst_ctrl, inline) + return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline), macro_igemm_2d_shared_store_t(self.mc, wei_sst_ctrl, inline) # computation macro def get_macro_in_update_hw(self): @@ -918,20 +829,21 @@ def get_macro_set_flag_hw(self): return self.macro_set_flag_hw(self.mc, inline) def get_symbol_global_load_s_stride_d0_d1(self): - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() # get the symbol object that load 2d may use s = self.sgpr s_dummy = sym_t("s_dummy") in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() - # [ta_n0, ta_n1b, ta_e, ta_c] - in_stride_gprs = [s.s_in_stride_n0 if ta_n0 != 1 else s_dummy, - s.s_in_stride_wi, + # [ta_nb0, ta_nb1, ta_e, ta_c] + in_stride_gprs = [s_dummy, # complecated + s_dummy, # complecated s_dummy, s.s_stride_c] - # [tb_k, ta_e, ta_c] - wei_stride_gprs = [s.s_wei_stride_k, + # [tb_k0, tb_k1, ta_e, ta_c] + wei_stride_gprs = [s.s_wei_stride_k0 if tb_k0 != 1 else s_dummy, + s.s_wei_stride_k if tb_k1 != 1 else s_dummy, s_dummy, s.s_stride_c] @@ -1076,24 +988,23 @@ def emit_kernel_prologue(self): s = self.sgpr v = self.vgpr k = self.karg - gemm_m_unmerge_cluster = self.tunable.gemm_m_unmerge_cluster - gemm_n_unmerge_cluster = self.tunable.gemm_n_unmerge_cluster - gemm_k_unmerge_cluster = self.tunable.gemm_k_unmerge_cluster - - assert gemm_m_unmerge_cluster == 0 and gemm_n_unmerge_cluster == 0 and gemm_k_unmerge_cluster == 0, 'in fwd nhwc, gemm_m/n/k unmerge_cluster only support 0' + # gemm_m_unmerge_cluster = self.tunable.gemm_m_unmerge_cluster + # gemm_n_unmerge_cluster = self.tunable.gemm_n_unmerge_cluster + # gemm_k_unmerge_cluster = self.tunable.gemm_k_unmerge_cluster - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() - ca_n0, ca_n1b, ca_e, ca_c, cb_k = self.get_cluster_lengths() - na_n0, na_n1b, na_e, na_c, nb_k = self.get_dims_lengths() + # assert gemm_m_unmerge_cluster == 0 and gemm_n_unmerge_cluster == 0 and gemm_k_unmerge_cluster == 0, 'in fwd nhwc, gemm_m/n/k unmerge_cluster only support 0' - unmerge_sub_n = self.tunable.unmerge_sub_n - if gemm_n_unmerge_cluster == 0: - assert unmerge_sub_n % na_n0 == 0, f"unmerge_sub_n:{unmerge_sub_n}, na_n0:{na_n0}" - unmerge_sub_n1 = unmerge_sub_n // na_n0 - assert na_n1b % unmerge_sub_n1 == 0, f"na_n1b:{na_n1b}, unmerge_sub_n1:{unmerge_sub_n1}" + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = self.get_cluster_lengths() + na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() - else: - assert False, f"unsupported gemm_n_unmerge_cluster:{self.tunable.gemm_n_unmerge_cluster}" + # unmerge_sub_n = self.tunable.unmerge_sub_n + # if gemm_n_unmerge_cluster == 0: + # assert unmerge_sub_n % na_nb0 == 0, f"unmerge_sub_n:{unmerge_sub_n}, na_nb0:{na_nb0}" + # unmerge_sub_n1 = unmerge_sub_n // na_nb0 + # assert na_nb1 % unmerge_sub_n1 == 0, f"na_nb1:{na_nb1}, unmerge_sub_n1:{unmerge_sub_n1}" + # else: + # assert False, f"unsupported 
gemm_n_unmerge_cluster:{self.tunable.gemm_n_unmerge_cluster}" data_byte = amdgpu_precision_data_byte(self.tunable.precision) @@ -1121,7 +1032,7 @@ def emit_kernel_prologue(self): m_int_div_rem_vs = macro_int_div_rem_vs_t(self.mc) m_int_div_rem_ss = macro_int_div_rem_ss_t(self.mc) - gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() + # gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() s_dummy = sym_t("s_dummy") # start emit @@ -1143,17 +1054,17 @@ def emit_kernel_prologue(self): self._emit(f"s_load_dword s[{s.s_magic_6()}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_6()}") self._emit(f"s_load_dwordx2 s[{s.s_shift_pack_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") - self._emit(f"; in(e, c, n0, n1b) thread_lengths: {ta_e}x{ta_c}x{ta_n0}x{ta_n1b}, cluster_length: {ca_e}x{ca_c}x{ca_n0}x{ca_n1b}, unmerge_sub_n:{unmerge_sub_n}, unmerge_sub_n1:{unmerge_sub_n1}") + self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, unmerge_sub_n:{unmerge_sub_n}, unmerge_sub_n1:{unmerge_sub_n1}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") self._emit(tc_index_dispatcher(v.v_gtc_ta_ic(), v.v_tmp(), ca_c, ta_c)) - if ca_n0 != 1: + if ca_nb0 != 1: # TODO: this is not wanted - self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_n1b, ta_n1b)) - self._emit(tc_index_dispatcher(v.v_gtc_ta_in0(), v.v_tmp(), ca_n0, ta_n0, True)) + self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_nb1, ta_nb1)) + self._emit(tc_index_dispatcher(v.v_gtc_ta_in0(), v.v_tmp(), ca_nb0, ta_nb0, True)) else: - self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_n1b, ta_n1b, True)) + self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_nb1, ta_nb1, True)) - self._emit(f"; wei(e, c, k) thread_length: {ta_e}x{ta_c}x{tb_k}, cluster_length: {ca_e}x{ca_c}x{cb_k}") + self._emit(f"; wei(e, c, k0, k1) thread_length: {ta_e}x{ta_c}x{tb_k0}x{tb_k1}, cluster_length: {ca_e}x{ca_c}x{cb_k0}x{cb_k1}") # weight ic same as input self._emit(f"v_lshrrev_b32 v[{v.v_tmp()}], {igemm_log2(ca_c)}, v0") self._emit(tc_index_dispatcher(v.v_gtc_tb_ik(), v.v_tmp(), cb_k, tb_k, True)) @@ -1181,7 +1092,7 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_in_stride_hi()}]") - if ta_n0 != 1: + if ta_nb0 != 1: self._emit(f"s_lshl_b32 s[{s.s_in_stride_n0()}], s[{s.s_in_stride_n()}], {utility_log2(unmerge_sub_n1)}") # weight self._emit(f"s_mul_i32 s[{s.s_wei_stride_y()}], s[{s.s_x()}], s[{s.s_c()}]") @@ -1190,7 +1101,7 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wo()}], s[{s.s_out_stride_wo()}]") self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho()}], s[{s.s_tmp(1)}]") - if ta_n0 != 1: + if ta_nb0 != 1: self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") else: @@ -1198,7 +1109,7 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_in_stride_hi()}]") - if ta_n0 != 1: + if ta_nb0 != 1: self._emit(f"s_lshl_b32 
s[{s.s_in_stride_n0()}], s[{s.s_in_stride_n()}], {utility_log2(unmerge_sub_n1)}") # weight self._emit(f"s_mov_b32 s[{s.s_wei_stride_k()}], s[{s.s_c()}]") @@ -1206,7 +1117,7 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wi()}], s[{s.s_out_stride_wo()}]") self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_hi()}], s[{s.s_tmp(1)}]") - if ta_n0 != 1: + if ta_nb0 != 1: self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") # early init s_knum in case shifted @@ -1270,15 +1181,15 @@ def emit_kernel_prologue(self): self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ik()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_n_per_block)}") - # to compute ho*wo*sub_n1 // na_n1b + # to compute ho*wo*sub_n1 // na_nb1 if unmerge_sub_n1 == 1: - self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_n1b)} ; total number of n1b") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1)} ; total number of nb1") else: - assert na_n1b >= unmerge_sub_n1 - if unmerge_sub_n1 == na_n1b: - self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of n1b") + assert na_nb1 >= unmerge_sub_n1 + if unmerge_sub_n1 == na_nb1: + self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of nb1") else: - self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_n1b // unmerge_sub_n1)} ; total number of n1b") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1 // unmerge_sub_n1)} ; total number of nb1") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") @@ -1286,14 +1197,14 @@ def emit_kernel_prologue(self): else: self._emit(m_int_div_rem_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) - if na_n1b != 1: - self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in1b()}], s[{s.s_block_gtc_in1b()}], {igemm_log2(na_n1b)}") - if na_n0 != 1: - self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in0()}], s[{s.s_block_gtc_in0()}], {igemm_log2(na_n0)}") + if na_nb1 != 1: + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in1b()}], s[{s.s_block_gtc_in1b()}], {igemm_log2(na_nb1)}") + if na_nb0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in0()}], s[{s.s_block_gtc_in0()}], {igemm_log2(na_nb0)}") self._emit_empty_line() - self._emit(f"; in n1b transform") - if ca_n1b == 1: + self._emit(f"; in nb1 transform") + if ca_nb1 == 1: self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}]") else: self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_gtc_ta_in1b()}]") @@ -1339,8 +1250,8 @@ def emit_kernel_prologue(self): self._emit_empty_line() #if gemm_m_unmerge_cluster == 0: - if ca_n0 != 1: - self._emit(tc_index_accumulator(v.v_tmp(1), v.v_gtc_ta_in0(), v.v_gtc_ta_in1(), ca_n0, ca_n1b, 0, unmerge_sub_n1)) + if ca_nb0 != 1: + self._emit(tc_index_accumulator(v.v_tmp(1), v.v_gtc_ta_in0(), v.v_gtc_ta_in1(), ca_nb0, ca_nb1, 0, unmerge_sub_n1)) self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_tmp(1)}]") else: self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_gtc_ta_in1()}]") @@ -1429,16 +1340,16 @@ def emit_kernel_prologue(self): self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") self._emit(self.xdlops_mapping.get_gemm_index_for_dst_matrix(v.v_co_sst(), v.v_co_sld(), v.v_tmp(5), v.v_tmp())) - self._emit(f"; LDS store, in: e,c,n0,n1b: {ta_e}x{ta_c}x{ta_n0}x{ta_n1b}, 
{ca_e}x{ca_c}x{ca_n0}x{ca_n1b}") - if ca_n1b == 1: + self._emit(f"; LDS store, in: e,c,nb0,nb1: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}") + if ca_nb1 == 1: # TODO: remove this path, not possible go here assert False else: - if ca_n0 == 1: + if ca_nb0 == 1: self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in1b()}]") else: - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_n1b)}, v[{v.v_gtc_ta_in1b()}]") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(na_n0*na_n1b)}, v[{v.v_tmp()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_nb1)}, v[{v.v_gtc_ta_in1b()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(na_nb0*na_nb1)}, v[{v.v_tmp()}]") self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") self._emit_empty_line() @@ -1485,19 +1396,19 @@ def emit_kernel_prologue(self): self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out()}+1], 0") self._emit_empty_line() - self._emit(f"; compute v_co_sub_m_index along n0 x n1b : {na_n0}x{na_n1b}") + self._emit(f"; compute v_co_sub_m_index along nb0 x nb1 : {na_nb0}x{na_nb1}") if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B: - if na_n1b != 1: - self._emit(f"v_and_b32 v[{v.v_out_in1b()}], {na_n1b - 1}, v[{v.v_co_sub_m_index()}] ; => N1B") - if na_n0 != 1: - self._emit(f"v_lshrrev_b32 v[{v.v_out_in0()}], {igemm_log2(na_n1b)}, v[{v.v_co_sub_m_index()}] ; => N0") + if na_nb1 != 1: + self._emit(f"v_and_b32 v[{v.v_out_in1b()}], {na_nb1 - 1}, v[{v.v_co_sub_m_index()}] ; => N1B") + if na_nb0 != 1: + self._emit(f"v_lshrrev_b32 v[{v.v_out_in0()}], {igemm_log2(na_nb1)}, v[{v.v_co_sub_m_index()}] ; => N0") else: assert False, "un implemented, should rarely be used" else: assert False # TODO: extend tensor size, here vgpr only have 32bit - self._emit(f"; compute from n1b") + self._emit(f"; compute from nb1") self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_out_in1b()}]") if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: @@ -1521,7 +1432,7 @@ def emit_kernel_prologue(self): self._emit_empty_line() self._emit_empty_line() self._emit(f"; add in_in0, in_in1") - if na_n0 != 1: + if na_nb0 != 1: #if gemm_m_unmerge_cluster == 0: self._emit(f"v_lshl_or_b32 v[{v.v_tmp(1)}], v[{v.v_out_in0()}], {igemm_log2(unmerge_sub_n1)}, v[{v.v_out_in1()}]") self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], s[{s.s_out_stride_n()}], v[{v.v_tmp(1)}]") @@ -1773,8 +1684,8 @@ def emit_kernel_epilogue(self): v = self.vgpr #label_out = f"L_{self.name()}_out" - ta_n0, ta_n1b, ta_e, ta_c, tb_k = self.get_thread_lengths() - ca_n0, ca_n1b, ca_e, ca_c, cb_k = self.get_cluster_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = self.get_cluster_lengths() if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # if self.tunable.nxe != 0: @@ -1787,7 +1698,7 @@ def emit_kernel_epilogue(self): else: a = self.agpr self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None, - s.s_out_stride_n0() if ta_n0 != 1 else None, s.s_out_stride_wo(), + s.s_out_stride_n0() if ta_nb0 != 1 else None, s.s_out_stride_wo(), s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else None, s.s_k(), v.v_cur_k(), s.s_block_gtc_ik(), v.v_co_sub_m_index(), v.v_tmp())) self._emit_front(f"{self.label_out}:") diff --git a/igemm/algo/shared_memory.py 
b/igemm/algo/shared_memory.py index 7b55cffe..58d7849b 100644 --- a/igemm/algo/shared_memory.py +++ b/igemm/algo/shared_memory.py @@ -844,3 +844,86 @@ def get_issues(self): with self._deferred_context(): self.emit() return self.issue_cnt + +class ctrl_3d_shared_store_t(object): + ''' + d0 x d1 x dp (d pack) + ''' + def __init__(self): + self.length_d0 = 1 # is d0 is 1, it is indeed 1d access + self.length_d1 = 1 + self.length_dp = 1 + self.stride_d0 = 1 # stride + self.stride_d1 = 1 # if have stride_d1, then each d1 may have stride + self.precision = 'fp32' # 'fp32', 'fp16', ... + self.src_order = 0 # 0-d0,d1, 1-d1,d0 + self.need_transpose = 1 + self.v_tmp = None # used when order is 1 and consider shuffle + + def serialize(self): + return f"length_d0:{self.length_d0}, length_d1:{self.length_d1}, length_dp:{self.length_dp}, stride_d0:{self.stride_d0}, stride_d1:{self.stride_d1}, precision:{self.precision}, src_order:{self.src_order}" + +class macro_igemm_3d_shared_store_t(macro_base_t): + ''' + this is indeed for + 0: gemm_k * gemm_m/n * k_pack + 1: gemm_m/n * gemm_k * k_pack + we always want to use k_pack as vector store + ''' + def __init__(self, mc, ctrl, inline = False): + assert type(ctrl) is ctrl_3d_shared_store_t + macro_base_t.__init__(self, mc, inline) + self.ctrl = ctrl + self.issue_cnt = 0 + self.declare_arg("v_src") + self.declare_arg("v_sst_os") + def name(self): + ctrl = self.ctrl + if ctrl.precision == "fp32": + bits_str = 'b32' + elif ctrl.precision in ("fp16", "bf16"): + bits_str = 'b16' + else: + assert False + + return f".v_sst_so{ctrl.src_order}_{ctrl.length_d0}x{ctrl.length_d1}x{ctrl.length_dp}_{bits_str}" + \ + f"_st{ctrl.stride_d0}x{ctrl.stride_d1}" + + def expr(self): + ctrl = self.ctrl + assert ctrl.precision == 'fp32', "TO BE supported" + data_byte = amdgpu_precision_data_byte(ctrl.precision) + issue_cnt = 0 + + if ctrl.length_d0 == 1 or ctrl.length_d1 == 1: + # this is indeed a 2d case. 
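
This 3D store targets an LDS tile laid out as d0 x d1 x dp with the k_pack (dp) elements contiguous, which is what allows every ds_write to be a vector write of length_dp elements. A hedged sketch of the byte offset that layout implies (fp32 assumed); for the input tile d0/d1 are nb0/nb1 with d1_extent = na_nb1 and k_pack = ta_c, for the weight tile they are k0/k1 with d1_extent = nb_k1, matching the stride_d0/stride_d1 set in get_macro_shared_store earlier in this patch:

def lds_3d_offset(i_d0, i_d1, i_p, d1_extent, k_pack, data_byte=4):
    # stride_d1 = k_pack * data_byte, stride_d0 = d1_extent * k_pack * data_byte,
    # i.e. the k_pack elements of one (d0, d1) position sit contiguously in LDS
    stride_d1 = k_pack * data_byte
    stride_d0 = d1_extent * stride_d1
    return i_d0 * stride_d0 + i_d1 * stride_d1 + i_p * data_byte
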
+ + if ctrl.length_d0 == 1 and ctrl.length_d1 == 1: + # further, 1d case + ds_write = inst_ds_write_t(ctrl.length_dp * data_byte) + self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}')) + issue_cnt += ds_write.get_issues() + + else: + length_d = ctrl.length_d0 if ctrl.length_d0 != 1 else ctrl.length_d1 + stride_d = ctrl.stride_d0 if ctrl.length_d0 != 1 else ctrl.stride_d1 + if length_d % 2 == 0 and data_byte == 4 and ctrl.length_dp in (1, 2): + ds_write2 = inst_ds_write2_likely_t(self.mc, 2, ctrl.length_dp * data_byte, stride_d) + for i_d in range(length_d // 2): + self._emit(ds_write2(f'{self.v_sst_os()}', f'{self.v_src()}+{2 * i_d*ctrl.length_dp}', 2 * i_d * stride_d)) + issue_cnt += ds_write2.get_issues(2 * i_d * stride_d) + else: + for i_d in range(length_d): + self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}+{s_id}', i_d0 * ctrl.stride_d0)) + issue_cnt += ds_write.get_issues() + else: + assert False, "un implemented yet" + + self.issue_cnt = issue_cnt + + def get_issues(self): + #assert False, "tobe implemented" + #return self.ctrl.length_d0 + with self._deferred_context(): + self.emit() + return self.issue_cnt From 88a3c36bbe1b181ae43e61413ce037940b9cd3ed Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 21 Jan 2021 20:57:58 +0800 Subject: [PATCH 07/40] more add --- igemm/algo/igemm_fwd_gtc_nhwc.py | 163 ++++++++++++++++++++++--------- 1 file changed, 117 insertions(+), 46 deletions(-) diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index f6d7144c..5b210ee7 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -521,6 +521,8 @@ def __init__(self, mc, outer): #if outer.tunable.nxe != 0: self.s_dim_b = sym_t("s_dim_b" , sseq(1)) + self.s_dim_m = sym_t("s_dim_m" , sseq(1)) + self.s_dim_n = sym_t("s_dim_n" , sseq(1)) self.s_kitr = sym_t("s_kitr" , 1) if outer.tunable.precache_soffset: @@ -529,7 +531,7 @@ def __init__(self, mc, outer): wei_npc = m_wei_2d_global_load.get_num_precache_soffset() self.s_in_offset = sym_t("s_in_offset" ,sseq(in_npc)) # if this number is zero, it is also OK, since we would not use self.s_wei_offset = sym_t("s_wei_offset" ,sseq(wei_npc)) - self.s_k_padded = sym_t("s_k_padded" ,sseq(1)) + # self.s_k_padded = sym_t("s_k_padded" ,sseq(1)) # TODO: this sgpr allocation is a mess if IGEMM_GTC_FEAT_MAGIC_DIVISION: @@ -597,9 +599,9 @@ def __init__(self, mc, outer): self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) self.v_gtc_ta_ic = sym_t("v_gtc_ta_ic" ,vseq(1)) - if ca_nb0 != 1: - self.v_gtc_ta_in0 = sym_t("v_gtc_ta_in0" ,vseq(1)) - self.v_gtc_ta_in1b = sym_t("v_gtc_ta_in1b" ,vseq(1)) + #if ca_nb0 != 1: + # self.v_gtc_ta_in0 = sym_t("v_gtc_ta_in0" ,vseq(1)) + self.v_gtc_ta_inb1 = sym_t("v_gtc_ta_inb1" ,vseq(1)) self.v_gtc_ta_in1 = sym_t("v_gtc_ta_in1" ,vseq(1)) if tb_k0 != 1: @@ -1056,13 +1058,8 @@ def emit_kernel_prologue(self): self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, unmerge_sub_n:{unmerge_sub_n}, unmerge_sub_n1:{unmerge_sub_n1}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") - self._emit(tc_index_dispatcher(v.v_gtc_ta_ic(), v.v_tmp(), ca_c, ta_c)) - if ca_nb0 != 1: - # TODO: this is not wanted - self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_nb1, ta_nb1)) - self._emit(tc_index_dispatcher(v.v_gtc_ta_in0(), v.v_tmp(), ca_nb0, ta_nb0, True)) - else: - self._emit(tc_index_dispatcher(v.v_gtc_ta_in1b(), v.v_tmp(), ca_nb1, ta_nb1, True)) + self._emit(tc_index_dispatcher(v.v_gtc_ta_ic(), 
v.v_tmp(), ca_c, ta_c)) + self._emit(tc_index_dispatcher(v.v_gtc_ta_inb1(), v.v_tmp(), ca_nb1, ta_nb1, True)) self._emit(f"; wei(e, c, k0, k1) thread_length: {ta_e}x{ta_c}x{tb_k0}x{tb_k1}, cluster_length: {ca_e}x{ca_c}x{cb_k0}x{cb_k1}") # weight ic same as input @@ -1121,19 +1118,34 @@ def emit_kernel_prologue(self): self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") # early init s_knum in case shifted - #if self.tunable.nxe != 0: self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_wei_stride_k()}]") - #else: - # self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_c()}]") - # warp around the really dim_b length, in case pad + # pad gemm_m, gemm_n if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_ho()}], s[{s.s_wo()}]") - self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.nxb - 1}, s[{s.s_tmp(4)}]") - self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}") - self._emit(f"s_lshl_b32 s[{s.s_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_ho()}], s[{s.s_wo()}]") else: - self._emit(f"s_mul_i32 s[{s.s_dim_b()}], s[{s.s_hi()}], s[{s.s_wi()}]") # no pad + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_hi()}], s[{s.s_wi()}]") + + self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_n()}], s[{s.s_tmp(1)}]") + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_m_per_block - 1}, s[{s.s_tmp(2)}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_m()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_m_per_block)}") + + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_n_per_block - 1}, s[{s.s_k()}]") + self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_n()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + + + + + # warp around the really dim_b length, in case pad + # if self.tunable.nxe != 0: + # self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_ho()}], s[{s.s_wo()}]") + # self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.nxb - 1}, s[{s.s_tmp(4)}]") + # self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}") + # self._emit(f"s_lshl_b32 s[{s.s_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}") + # else: + # self._emit(f"s_mul_i32 s[{s.s_dim_b()}], s[{s.s_hi()}], s[{s.s_wi()}]") # no pad # for gemm_m pad # self._emit_empty_line() @@ -1146,8 +1158,7 @@ def emit_kernel_prologue(self): self._emit(f"; gemm_m_per_block:{self.tunable.gemm_m_per_block}, gemm_n_per_block:{self.tunable.gemm_n_per_block}, source_access_order:{self.tunable.source_access_order}") # calculate group index TODO: use blockIdx.y as group index - self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_dim_b()}], s[{s.s_n()}]") - self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_dim_m()}], {igemm_log2(self.tunable.gemm_m_per_block)}") self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_k()}], {igemm_log2(self.tunable.gemm_n_per_block)}") self._emit(f"s_mul_i32 s[0], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") if IGEMM_GTC_FEAT_MAGIC_DIVISION: @@ -1160,7 +1171,7 @@ def emit_kernel_prologue(self): self._emit(f"s_mov_b32 s[{s.s_bx()}], s[{s.s_tmp(4)}]") if self.tunable.source_access_order == IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N: - 
self._emit(f"s_lshr_b32 s[0], s[{s.s_k()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_n()}], {igemm_log2(self.tunable.gemm_n_per_block)}") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) @@ -1168,9 +1179,7 @@ def emit_kernel_prologue(self): self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) else: - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_dim_b()}], s[{s.s_n()}]") - - self._emit(f"s_lshr_b32 s[0], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_m()}], {igemm_log2(self.tunable.gemm_m_per_block)}") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") self._emit(m_mdiv_u32_ss(s.s_tmp(5), s.s_tmp(4), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) @@ -1178,36 +1187,98 @@ def emit_kernel_prologue(self): self._emit(m_int_div_rem_ss(s.s_tmp(5), s.s_tmp(4), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) self._emit(f"; s_tmp+4:block_gtc_in, s_tmp+5:block_gtc_im") - self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ik()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(5)}], {igemm_log2(self.tunable.gemm_m_per_block)}") - # to compute ho*wo*sub_n1 // na_nb1 - if unmerge_sub_n1 == 1: - self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1)} ; total number of nb1") + if self.tunable.nxe != 0: + self._emit(f"s_mul_i32 s[0], s[{s.s_ho()}], s[{s.s_wo()}]") else: - assert na_nb1 >= unmerge_sub_n1 - if unmerge_sub_n1 == na_nb1: - self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of nb1") - else: - self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1 // unmerge_sub_n1)} ; total number of nb1") + self._emit(f"s_mul_i32 s[0], s[{s.s_hi()}], s[{s.s_wi()}]") if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), s.s_magic_1(), s.s_tmp(3), '0', s.s_tmp())) + assert False else: - self._emit(m_int_div_rem_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_ss(s.s_tmp(4), '1', s.s_tmp(5), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + self._emit(f"s_cmp_lt_u32 s[1], s[{s.s_n()}]") # always compare last dim, here is n + self._emit(f"s_cbranch_scc0 {self.label_out}") # jump to end, basically this would not happen, but happen within block - if na_nb1 != 1: - self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in1b()}], s[{s.s_block_gtc_in1b()}], {igemm_log2(na_nb1)}") - if na_nb0 != 1: - self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in0()}], s[{s.s_block_gtc_in0()}], {igemm_log2(na_nb0)}") - self._emit_empty_line() + if self.tunable.nxe != 0: + # add n + self_emit(f"s_lshl_b32 s[{s.s_tmp(2)}], s[1], {igemm_log2(data_byte)}") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_tmp(2)}], s[{s.s_in_stride_n()}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_tmp(2)}], s[{s.s_in_stride_n()}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + + self._emit(m_int_div_rem_ss('0', '1', 
s.s_tmp(4), s.s_wo(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) + # s0:s_iwo, s1:s_iho, ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h, + self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s1, s[{s.s_stride_h()}]") + self._emit(f"s_sub_i32 s[{s.s_in_ihi()}], s[{s.s_tmp(0)}], s[{s.s_pad_h()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s0, s[{s.s_stride_w()}]") + self._emit(f"s_sub_i32 s[{s.s_in_iwi()}], s[{s.s_tmp(1)}], s[{s.s_pad_w()}]") + + + self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_wi()}], s[{s.s_ihi()}]") + self._emit(f"s_add_u32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], s[{s.s_iwi()}]") + self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_in_stride_wi()}], {igemm_log2(data_byte)}") + + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_tmp(4)}], s[{s.s_tmp(5)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_tmp(4)}], s[{s.s_tmp(5)}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"s_mul_i32 s[{s.s_tmp(5)}], s[{s.s_ho()}], s[{s.s_wo()}]") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + assert False + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in, v.v_gtc_ta_inb1(), s.s_tmp(5), v.v_tmp(), s.s_tmp())) + self._emit(f"v_mul_lo_u32 v[{v.v_in_os()}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") # not shifted yet + # v_tmp5, n, v_tmp4, ho*wo + self._emit(m_int_div_rem_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + + #self._emit(f"v_add_u32 v[{v.v_in_iho()}], s1, v[{v.v_in_iho()}]") + #self._emit(f"v_add_u32 v[{v.v_in_iwo()}], s0, v[{v.v_in_iwo()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(0)}], s[{s.s_stride_h()}], v[{v.v_in_iho()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_stride_w()}], v[{v.v_in_iwo()}]") + self._emit(f"v_sub_i32 v[{v.v_in_ihi()}], v[{v.v_tmp(0)}], s[{s.s_pad_h()}]") + self._emit(f"v_sub_i32 v[{v.v_in_iwi()}], v[{v.v_tmp(1)}], s[{s.s_pad_w()}]") + + + else: + # 1x1, can treat nhw as a single dimension + self._emit(f"s_mul_i32 s[{s.s_tmp()}], {data_byte}, s[{s.s_tmp(5)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], {data_byte}, s[{s.s_tmp(5)}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + + + + + # if unmerge_sub_n1 == 1: + # self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1)} ; total number of nb1") + # else: + # assert na_nb1 >= unmerge_sub_n1 + # if unmerge_sub_n1 == na_nb1: + # self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of nb1") + # else: + # self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1 // unmerge_sub_n1)} ; total number of nb1") + + + # if IGEMM_GTC_FEAT_MAGIC_DIVISION: + # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + # self._emit(m_mdiv_u32_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), s.s_magic_1(), s.s_tmp(3), '0', s.s_tmp())) + # else: + # self._emit(m_int_div_rem_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) + # if na_nb1 != 1: + # self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in1b()}], s[{s.s_block_gtc_in1b()}], {igemm_log2(na_nb1)}") + # if na_nb0 != 1: + # self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in0()}], s[{s.s_block_gtc_in0()}], {igemm_log2(na_nb0)}") + # self._emit_empty_line() self._emit(f"; in nb1 transform") if ca_nb1 == 1: self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}]") else: - self._emit(f"v_add_u32 
v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_gtc_ta_in1b()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_gtc_ta_inb1()}]") if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") @@ -1346,9 +1417,9 @@ def emit_kernel_prologue(self): assert False else: if ca_nb0 == 1: - self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in1b()}]") + self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_inb1()}]") else: - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_nb1)}, v[{v.v_gtc_ta_in1b()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_nb1)}, v[{v.v_gtc_ta_inb1()}]") self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(na_nb0*na_nb1)}, v[{v.v_tmp()}]") self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") self._emit_empty_line() From 05219e40358d89b89eb69f42a3c6f2964b0406a9 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 24 Jan 2021 17:26:09 +0800 Subject: [PATCH 08/40] update code --- igemm/algo/global_memory.py | 70 +++++ igemm/algo/igemm_fwd_gtc_nhwc.py | 520 +++++++++++++------------------ igemm/algo/shared_memory.py | 7 +- 3 files changed, 295 insertions(+), 302 deletions(-) diff --git a/igemm/algo/global_memory.py b/igemm/algo/global_memory.py index 100b98b2..d3abca2a 100755 --- a/igemm/algo/global_memory.py +++ b/igemm/algo/global_memory.py @@ -27,6 +27,21 @@ import sys from ..codegen import * +class inst_global_load_dword_t(object): + def __init__(self, dwords): + self.dwords = dwords + + def __call__(self, vdst, vaddr, saddr, offset = 0): + if self.dwords == 1: + return f"global_load_dword v[{vdst}], v[{vaddr}:{vaddr}+1], s[{srsrc}:{srsrc}+1], offset:{offset}" + if self.dwords == 2: + return f"global_load_dwordx2 v[{vdst}:{vdst}+1], v[{vaddr}:{vaddr}+1], s[{srsrc}:{srsrc}+1], offset:{offset}" + if self.dwords == 3: + return f"global_load_dwordx3 v[{vdst}:{vdst}+2], v[{vaddr}:{vaddr}+1], s[{srsrc}:{srsrc}+1], offset:{offset}" + if self.dwords == 4: + return f"global_load_dwordx4 v[{vdst}:{vdst}+3], v[{vaddr}:{vaddr}+1], s[{srsrc}:{srsrc}+1], offset:{offset}" + assert False + class inst_buffer_load_dword_t(object): ''' TODO: this implementation always offen ''' def __init__(self, dwords): @@ -353,6 +368,61 @@ def get_issues(self): n_d1 = ctrl.length_d1 // ctrl.vector_d1 return ctrl.length_d0 * n_d1 +class macro_igemm_2d_global_load_precache_voffset_t(macro_base_t): + ''' + not support src/dst order + ''' + def __init__(self, mc, ctrl, inline = False): + assert type(ctrl) is ctrl_2d_global_load_t + macro_base_t.__init__(self, mc, inline) + self.ctrl = ctrl + self.declare_arg("v_dst") + self.declare_arg("s_ptr") + self.declare_arg("v_os") + self.declare_arg("v_flag") + + def name(self): + ctrl = self.ctrl + if ctrl.precision == "fp32": + bits_str = 'b32' + elif ctrl.precision in ("fp16", "bf16"): + bits_str = 'b16' + else: + assert False + + if ctrl.vector_d1 == 4: + vec_str = 'v4' + elif ctrl.vector_d1 == 2: + vec_str = 'v2' + elif ctrl.vector_d1 == 1: + vec_str = 'v1' + else: + assert False + + return f".v_gld_{ctrl.length_d0}x{ctrl.length_d1}_{bits_str}_{vec_str}_precache_voffset" + + def expr(self): + ctrl = self.ctrl + assert ctrl.length_d1 % ctrl.vector_d1 == 0 + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + assert ctrl.precision == 'fp32', "TO BE supported" + buffer_load_dword = 
inst_buffer_load_dword_t(ctrl.vector_d1) + + i_cnt = 0 + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(n_d1): + if self.v_flag != None: + self._emit(f"v_cmpx_eq_u32 vcc, 1, v[{self.v_flag(i_cnt)}]") + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_cnt*ctrl.vector_d1}", f"{self.v_os(i_cnt)}", f"{self.s_ptr()}", 0, 0)) + if self.v_flag != None: + self._emit(f"s_mov_b32 exec, -1") + i_cnt += 1 + + def get_issues(self): + ctrl = self.ctrl + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + return ctrl.length_d0 * n_d1 + class macro_igemm_write_4d_strided_t(macro_base_t): ''' TODO: this is always not inline diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index 5b210ee7..e5786b09 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -183,6 +183,24 @@ def expr(self): self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, 1, vcc") self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_w()}], v[{self.v_iw()}]") self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag()}], vcc") + + class macro_set_flag_nhw(macro_base_t): + def __init__(self, mc, inline = False): + macro_base_t.__init__(self, mc, inline) + self.declare_arg("v_flag") + self.declare_arg("v_flag_n") + self.declare_arg("v_ih") + self.declare_arg("v_iw") + self.declare_arg("s_h") + self.declare_arg("s_w") + def name(self): + return '.v_fwd_gtc_nhwc_set_flag_nhw' + + def expr(self): + self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_h()}], v[{self.v_ih()}]") + self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag_n()}], vcc") + self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_w()}], v[{self.v_iw()}]") + self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag()}], vcc") class macro_in_update_hw_t(macro_base_t): def __init__(self, mc, inline = False): @@ -311,32 +329,15 @@ def __call__(self): v = self.outer.vgpr m_wei_2d_global_load, m_in_2d_global_load = self.outer.get_macro_global_load() - s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.outer.get_symbol_global_load_s_stride_d0_d1() with self._deferred_context(): self._emit(f"; load input") if self.outer.tunable.nxe != 0: self._emit(f".v_clear_nc {v.v_gld_a()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") - self._emit(f"v_cmp_eq_u32 vcc, 1, v[{v.v_in_flag()}]") - self._emit(f"s_and_saveexec_b64 s[{s.s_tmp(4)}:{s.s_tmp(5)}], vcc") - if self.outer.tunable.precache_soffset: - self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_in_offset())) + self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), v.v_in_flag())) else: - self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), s_in_stride_d0(), s_in_stride_d1(), s.s_tmp())) - if self.outer.tunable.nxe != 0: - self._emit(f"s_or_b64 exec, exec, s[{s.s_tmp(4)}:{s.s_tmp(5)}]") - return self._get_deferred() + self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), None)) - # def is_1d_move_slice_k(self): - # ''' - # this now only meaning for input tensor - # ''' - # na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() - # if self.tunable.nxe != 0: - # return False # if not nxe 0, it is possible that we can do move slice, but that will lead to extra index calculation - # if nb_c1e != 1 and nb_c0 == 1: - # return True - # # it is meanless to let n_c1e==1 and n_c0!=1 - # return False + return self._get_deferred() class global_load_wei_t(mc_base_t): def __init__(self, mc, outer): @@ -469,33 +470,31 @@ def __init__(self, mc, 
outer): self.s_group = sym_t('s_group' , sseq(1)) # stride for in - self.s_in_stride_hi = sym_t('s_in_stride_hi' , sseq(1)) + # self.s_in_stride_hi = sym_t('s_in_stride_hi' , sseq(1)) self.s_in_stride_wi = sym_t('s_in_stride_wi' , sseq(1)) self.s_in_stride_n = sym_t('s_in_stride_n' , sseq(1)) - if ta_nb0 != 1: - self.s_in_stride_n0 = sym_t('s_in_stride_n0' , sseq(1)) # stride for wei if tb_k0 != 1: self.s_wei_stride_k0 = sym_t('s_wei_stride_k0' , sseq(1)) self.s_wei_stride_k = sym_t('s_wei_stride_k' , sseq(1)) - if outer.tunable.nxe != 0: - self.s_wei_stride_y = sym_t('s_wei_stride_y' , sseq(1)) + #if outer.tunable.nxe != 0: + # self.s_wei_stride_y = sym_t('s_wei_stride_y' , sseq(1)) self.s_stride_c = sym_t('s_stride_c' , sseq(1)) # stride for out self.s_out_stride_wo = sym_t('s_out_stride_wo' , sseq(1)) self.s_out_stride_n = sym_t('s_out_stride_n' , sseq(1)) - if ta_nb0 != 1: - self.s_out_stride_n0 = sym_t('s_out_stride_n0' , sseq(1)) self.s_in_stride_c_c1 = sym_t("s_in_stride_c_c1" , sseq(1)) self.s_in_stride_c_c0_c1_diff = sym_t("s_in_stride_c_c0_c1_diff" , sseq(1)) self.s_block_gtc_ig = sym_t("s_block_gtc_ig" , sseq(1)) self.s_block_gtc_ik = sym_t("s_block_gtc_ik" , sseq(1)) - self.s_block_gtc_in0 = sym_t("s_block_gtc_in0" , sseq(1)) - self.s_block_gtc_in1b = sym_t("s_block_gtc_in1b" , sseq(1)) + self.s_block_gtc_inb = sym_t("s_block_gtc_inb" , sseq(1)) + + # self.s_block_gtc_in0 = sym_t("s_block_gtc_in0" , sseq(1)) + # self.s_block_gtc_in1b = sym_t("s_block_gtc_in1b" , sseq(1)) self.s_move_slice_k_c1e = sym_t("s_move_slice_k_c1e" , sseq(1)) if outer.tunable.nxe != 0: @@ -524,12 +523,27 @@ def __init__(self, mc, outer): self.s_dim_m = sym_t("s_dim_m" , sseq(1)) self.s_dim_n = sym_t("s_dim_n" , sseq(1)) + if outer.tunable.nxe != 0: + self.s_len_h = sym_t("s_len_h" , sseq(1)) + self.s_len_w = sym_t("s_len_w" , sseq(1)) + self.s_lim_h = sym_t("s_lim_h" , sseq(1)) # used to compare ih, will increase while y increase + self.s_lim_w = sym_t("s_lim_w" , sseq(1)) # used to compare iw, will increase while x increase + else: + self.s_len_h = sym_t("s_len_h" , self.s_hi.value) + self.s_len_w = sym_t("s_len_w" , self.s_wi.value) + self.s_lim_h = sym_t("s_lim_h" , self.s_hi.value) + self.s_lim_w = sym_t("s_lim_w" , self.s_wi.value) + + self.s_thread_stride_w = sym_t("s_thread_stride_w" , sseq(1)) + self.s_thread_stride_h = sym_t("s_thread_stride_h" , sseq(1)) + self.s_thread_stride_n = sym_t("s_thread_stride_n" , sseq(1)) + self.s_kitr = sym_t("s_kitr" , 1) if outer.tunable.precache_soffset: m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() - in_npc = m_in_2d_global_load.get_num_precache_soffset() + #in_npc = m_in_2d_global_load.get_num_precache_soffset() wei_npc = m_wei_2d_global_load.get_num_precache_soffset() - self.s_in_offset = sym_t("s_in_offset" ,sseq(in_npc)) # if this number is zero, it is also OK, since we would not use + #self.s_in_offset = sym_t("s_in_offset" ,sseq(in_npc)) # if this number is zero, it is also OK, since we would not use self.s_wei_offset = sym_t("s_wei_offset" ,sseq(wei_npc)) # self.s_k_padded = sym_t("s_k_padded" ,sseq(1)) @@ -569,6 +583,8 @@ def __init__(self, mc, outer): ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = outer.get_thread_lengths() ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = outer.get_cluster_lengths() + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + is_vgpr_acc_c = outer.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS vseq = gpr_sequencer_t() if is_vgpr_acc_c: @@ -592,41 +608,47 @@ def __init__(self, mc, outer): 
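+        # --- editorial sketch, not part of the original patch ----------------------
+        # From this hunk on, the input side keeps one global offset and one flag per
+        # nb element owned by the thread, so v_in_os / v_in_flag are sized by
+        # nb_per_thread instead of 1.  A minimal model of that sizing (assuming the
+        # e,c,nb0,nb1 thread lengths of tensor a, with only one of nb0/nb1 non-unit):
+        #
+        #     def nb_per_thread(ta_nb0, ta_nb1):
+        #         return ta_nb0 if ta_nb0 != 1 else ta_nb1
+        #
+        #     # e.g. ta_nb0=4, ta_nb1=1 -> 4 precomputed input offsets per thread
+        # ----------------------------------------------------------------------------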
self.v_sst_b_os = sym_t("v_sst_b_os" ,vseq(1)) self.v_sld_a_os = sym_t("v_sld_a_os" ,vseq(1)) self.v_sld_b_os = sym_t("v_sld_b_os" ,vseq(1)) - self.v_in_os = sym_t("v_in_os" ,vseq(1)) - self.v_in_os_base = sym_t("v_in_os_base" ,vseq(1)) + + # self.v_in_os_base = sym_t("v_in_os_base" ,vseq(1)) + self.v_in_os = sym_t("v_in_os" ,vseq(nb_per_thread)) if outer.tunable.nxe != 0: - self.v_in_flag = sym_t("v_in_flag" ,vseq(1)) + self.v_in_flag = sym_t("v_in_flag" ,vseq(nb_per_thread)) + self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) - self.v_gtc_ta_ic = sym_t("v_gtc_ta_ic" ,vseq(1)) + self.v_gtc_ic = sym_t("v_gtc_ic" ,vseq(1)) #if ca_nb0 != 1: # self.v_gtc_ta_in0 = sym_t("v_gtc_ta_in0" ,vseq(1)) - self.v_gtc_ta_inb1 = sym_t("v_gtc_ta_inb1" ,vseq(1)) - self.v_gtc_ta_in1 = sym_t("v_gtc_ta_in1" ,vseq(1)) + self.v_in_inb = sym_t("v_in_inb" ,vseq(1)) + #self.v_gtc_ta_in1 = sym_t("v_gtc_ta_in1" ,vseq(1)) + + self.v_flag_n = sym_t("v_flag_n" ,vseq(1)) # this flag will not change while move_slice_window - if tb_k0 != 1: - self.v_gtc_tb_ik0 = sym_t("v_gtc_tb_ik0" ,vseq(1)) - self.v_gtc_tb_ik = sym_t("v_gtc_tb_ik" ,vseq(1)) + # if tb_k0 != 1: + # self.v_wei_ik0 = sym_t("v_wei_ik0" ,vseq(1)) + self.v_wei_ik = sym_t("v_wei_ik" ,vseq(1)) self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) - self.v_out_os = sym_t("v_out_os" ,vseq(1)) + self.v_out_os = sym_t("v_out_os" ,vseq(1)) if outer.tunable.nxe != 0: - self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) - self.v_out_in0 = sym_t("v_out_in0" ,vseq(1)) - self.v_out_in1b = sym_t("v_out_in1b" ,vseq(1)) - self.v_out_in1 = sym_t("v_out_in1" ,vseq(1)) - - self.v_in_iho = sym_t("v_in_iho" ,vseq(1)) - self.v_in_iwo = sym_t("v_in_iwo" ,vseq(1)) - self.v_in_ihi = sym_t("v_in_ihi" ,vseq(1)) - self.v_in_iwi = sym_t("v_in_iwi" ,vseq(1)) + self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) + self.v_out_in0 = sym_t("v_out_in0" ,vseq(1)) + self.v_out_in1b = sym_t("v_out_in1b" ,vseq(1)) + self.v_out_in1 = sym_t("v_out_in1" ,vseq(1)) + + self.v_in_iho = sym_t("v_in_iho" ,vseq(1)) + self.v_in_iwo = sym_t("v_in_iwo" ,vseq(1)) + self.v_in_ihi = sym_t("v_in_ihi" ,vseq(1)) + self.v_in_iwi = sym_t("v_in_iwi" ,vseq(1)) + self.v_in_in = sym_t("v_in_in" ,vseq(1)) + if outer.tunable.nxe != 0: self.v_in_iy = sym_t("v_in_iy" ,vseq(1)) self.v_in_ix = sym_t("v_in_ix" ,vseq(1)) - self.v_move_slice_k_ic = sym_t("v_move_slice_k_ic1" , self.v_gtc_ta_ic.value) + self.v_move_slice_k_ic = sym_t("v_move_slice_k_ic1" , self.v_gtc_ic.value) if outer.tunable.nxe != 0: self.v_move_slice_k_iy = sym_t("v_move_slice_k_iy", self.v_in_iy.value) self.v_move_slice_k_ix = sym_t("v_move_slice_k_ix", self.v_in_ix.value) @@ -686,12 +708,9 @@ def get_thread_lengths(self): assert ta_e == 1, "currently not support >1 in e dimension" - if self.tunable.nxe == 0: - #assert ta_c0 == 1 - #assert tb_c0 == 1 - pass - else: - pass + # it's no point to have both x0, x1 have copy value + assert ta_nb0 != 1 and ta_nb1 != 1 + assert tb_k0 != 1 and tb_k1 != 1 return ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 # M, K, N @@ -704,6 +723,7 @@ def get_cluster_lengths(self): ca_e, ca_c, ca_nb0, ca_nb1 = c_ta[0], c_ta[1], c_ta[2], c_ta[3] cb_e, cb_c, cb_k0, cb_k1 = c_tb[0], c_tb[1], c_tb[2], c_tb[3] + assert ca_nb1 != 1 assert ca_e == cb_e and ca_c == cb_c assert ca_e == 1 and ca_nb0 == 1 and cb_k0 == 1 @@ -722,7 +742,7 @@ def get_dims_lengths(self): def get_thread_copy_dims(self): ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() in_thread_copy_dims = [ta_nb0, ta_nb1, ta_e, ta_c] - 
wei_thread_copy_dims = [tb_k0, tb_k1, ta_e, ta_c] # always reordered! + wei_thread_copy_dims = [tb_k0, tb_k1, ta_e, ta_c] # always reordered! return in_thread_copy_dims, wei_thread_copy_dims def get_thread_copy_index(self): @@ -759,8 +779,7 @@ def get_macro_global_load(self): ctrl_wei_gld.length_d0 = 1 ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] else: - ctrl_wei_gld.length_d0 = 1 - ctrl_wei_gld.length_d1 = wei_thread_copy_dims[-1] + assert False if self.in_thread_copy_ndim == 2: ctrl_in_gld.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] @@ -769,15 +788,13 @@ def get_macro_global_load(self): ctrl_in_gld.length_d0 = 1 ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] else: - ctrl_in_gld.length_d0 = 1 - ctrl_in_gld.length_d1 = in_thread_copy_dims[-1] + assert False if self.tunable.precache_soffset: return macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_wei_gld, inline), \ - macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_in_gld, inline) + macro_igemm_2d_global_load_precache_voffset_t(self.mc, ctrl_in_gld, inline) else: - return macro_igemm_2d_global_load_t(self.mc, ctrl_wei_gld, inline), macro_igemm_2d_global_load_t(self.mc, ctrl_in_gld, inline) - + return macro_igemm_2d_global_load_t(self.mc, ctrl_wei_gld, inline), macro_igemm_2d_global_load_precache_voffset_t(self.mc, ctrl_in_gld, inline) def get_macro_shared_store(self): #in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() @@ -805,7 +822,7 @@ def get_macro_shared_store(self): wei_sst_ctrl.stride_d1 = k_pack * data_byte inline = True if self.tunable.fma_interleave else False - return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline), macro_igemm_2d_shared_store_t(self.mc, wei_sst_ctrl, inline) + return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline), macro_igemm_3d_shared_store_t(self.mc, wei_sst_ctrl, inline) # computation macro def get_macro_in_update_hw(self): @@ -830,6 +847,10 @@ def get_macro_set_flag_hw(self): inline = True if self.tunable.fma_interleave else False return self.macro_set_flag_hw(self.mc, inline) + def get_macro_set_flag_nhw(self): + inline = True if self.tunable.fma_interleave else False + return self.macro_set_flag_nhw(self.mc, inline) + def get_symbol_global_load_s_stride_d0_d1(self): ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() # get the symbol object that load 2d may use @@ -837,9 +858,10 @@ def get_symbol_global_load_s_stride_d0_d1(self): s_dummy = sym_t("s_dummy") in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() + # input is ignored # [ta_nb0, ta_nb1, ta_e, ta_c] - in_stride_gprs = [s_dummy, # complecated - s_dummy, # complecated + in_stride_gprs = [s_dummy, + s_dummy, s_dummy, s.s_stride_c] @@ -856,8 +878,7 @@ def get_symbol_global_load_s_stride_d0_d1(self): s_in_stride_d0 = s_dummy s_in_stride_d1 = in_stride_gprs[in_thread_copy_index[0]] else: - s_in_stride_d0 = s_dummy - s_in_stride_d1 = in_stride_gprs[-1] + assert False if self.wei_thread_copy_ndim == 2: # print(f" ____ wei_thread_copy_index:{len(wei_thread_copy_index)}, {wei_thread_copy_index}") @@ -867,12 +888,10 @@ def get_symbol_global_load_s_stride_d0_d1(self): s_wei_stride_d0 = s_dummy s_wei_stride_d1 = wei_stride_gprs[wei_thread_copy_index[0]] else: - s_wei_stride_d0 = s_dummy - s_wei_stride_d1 = wei_stride_gprs[-1] + assert False return s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 - def get_kernel_code(self): kernel_code = amdgpu_kernel_code_t({ 
'enable_sgpr_kernarg_segment_ptr' : 1, @@ -990,31 +1009,18 @@ def emit_kernel_prologue(self): s = self.sgpr v = self.vgpr k = self.karg - # gemm_m_unmerge_cluster = self.tunable.gemm_m_unmerge_cluster - # gemm_n_unmerge_cluster = self.tunable.gemm_n_unmerge_cluster - # gemm_k_unmerge_cluster = self.tunable.gemm_k_unmerge_cluster - - # assert gemm_m_unmerge_cluster == 0 and gemm_n_unmerge_cluster == 0 and gemm_k_unmerge_cluster == 0, 'in fwd nhwc, gemm_m/n/k unmerge_cluster only support 0' ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = self.get_cluster_lengths() na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() - # unmerge_sub_n = self.tunable.unmerge_sub_n - # if gemm_n_unmerge_cluster == 0: - # assert unmerge_sub_n % na_nb0 == 0, f"unmerge_sub_n:{unmerge_sub_n}, na_nb0:{na_nb0}" - # unmerge_sub_n1 = unmerge_sub_n // na_nb0 - # assert na_nb1 % unmerge_sub_n1 == 0, f"na_nb1:{na_nb1}, unmerge_sub_n1:{unmerge_sub_n1}" - # else: - # assert False, f"unsupported gemm_n_unmerge_cluster:{self.tunable.gemm_n_unmerge_cluster}" - data_byte = amdgpu_precision_data_byte(self.tunable.precision) m_in_update_hw = self.get_macro_in_update_hw() m_in_update_os = self.get_macro_in_update_os() - # m_wei_update_os = self.get_macro_wei_update_os() - # m_wei_update_yx = self.get_macro_wei_update_yx() - m_set_flag_hw = self.get_macro_set_flag_hw() + + m_set_flag_hw = self.get_macro_set_flag_hw() + m_set_flag_nhw = self.get_macro_set_flag_nhw() s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.get_symbol_global_load_s_stride_d0_d1() m_wei_2d_global_load, m_in_2d_global_load = self.get_macro_global_load() @@ -1022,9 +1028,6 @@ def emit_kernel_prologue(self): tc_index_dispatcher = igemm_thread_cluster_index_dispatcher_t(self.mc) tc_index_accumulator = igemm_thread_cluster_index_accumulator_t(self.mc) - m_int_div_rem_vv = macro_int_div_rem_vv_t(self.mc) - m_int_div_rem_vs = macro_int_div_rem_vs_t(self.mc) - m_int_div_rem_ss = macro_int_div_rem_ss_t(self.mc) if IGEMM_GTC_FEAT_MAGIC_DIVISION: m_mdiv_u32_vs = macro_mdiv_u32_rem_vs_t(self.mc) @@ -1034,7 +1037,6 @@ def emit_kernel_prologue(self): m_int_div_rem_vs = macro_int_div_rem_vs_t(self.mc) m_int_div_rem_ss = macro_int_div_rem_ss_t(self.mc) - # gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() s_dummy = sym_t("s_dummy") # start emit @@ -1056,15 +1058,15 @@ def emit_kernel_prologue(self): self._emit(f"s_load_dword s[{s.s_magic_6()}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_6()}") self._emit(f"s_load_dwordx2 s[{s.s_shift_pack_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") - self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, unmerge_sub_n:{unmerge_sub_n}, unmerge_sub_n1:{unmerge_sub_n1}") + self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") - self._emit(tc_index_dispatcher(v.v_gtc_ta_ic(), v.v_tmp(), ca_c, ta_c)) - self._emit(tc_index_dispatcher(v.v_gtc_ta_inb1(), v.v_tmp(), ca_nb1, ta_nb1, True)) + self._emit(tc_index_dispatcher(v.v_gtc_ic(), v.v_tmp(), ca_c, ta_c)) + self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), ca_nb1, ta_nb1, True)) self._emit(f"; wei(e, c, k0, k1) thread_length: {ta_e}x{ta_c}x{tb_k0}x{tb_k1}, cluster_length: {ca_e}x{ca_c}x{cb_k0}x{cb_k1}") # weight ic same as input self._emit(f"v_lshrrev_b32 v[{v.v_tmp()}], 
{igemm_log2(ca_c)}, v0") - self._emit(tc_index_dispatcher(v.v_gtc_tb_ik(), v.v_tmp(), cb_k, tb_k, True)) + self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k, tb_k, True)) self._emit_empty_line() self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") @@ -1084,82 +1086,52 @@ def emit_kernel_prologue(self): self._emit(f"; calculate index") # calculate stride, not shift data byte yet + # input + self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") + self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_tmp(2)}]") + + # weight if self.tunable.nxe != 0: - # input - self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") - self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_in_stride_hi()}]") - if ta_nb0 != 1: - self._emit(f"s_lshl_b32 s[{s.s_in_stride_n0()}], s[{s.s_in_stride_n()}], {utility_log2(unmerge_sub_n1)}") - # weight self._emit(f"s_mul_i32 s[{s.s_wei_stride_y()}], s[{s.s_x()}], s[{s.s_c()}]") self._emit(f"s_mul_i32 s[{s.s_wei_stride_k()}], s[{s.s_wei_stride_y()}], s[{s.s_y()}]") - # output - self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wo()}], s[{s.s_out_stride_wo()}]") - self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho()}], s[{s.s_tmp(1)}]") - if ta_nb0 != 1: - self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") - else: - # input - self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_in_stride_hi()}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") - self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_in_stride_hi()}]") - if ta_nb0 != 1: - self._emit(f"s_lshl_b32 s[{s.s_in_stride_n0()}], s[{s.s_in_stride_n()}], {utility_log2(unmerge_sub_n1)}") - # weight self._emit(f"s_mov_b32 s[{s.s_wei_stride_k()}], s[{s.s_c()}]") - # output - self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wi()}], s[{s.s_out_stride_wo()}]") - self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_hi()}], s[{s.s_tmp(1)}]") - if ta_nb0 != 1: - self._emit(f"s_lshl_b32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], {utility_log2(unmerge_sub_n1)}") + + if tb_k0 != 1: + self._emit(f"s_lshl_b32 s[{s.s_wei_stride_K0()}], s[{s.s_wei_stride_K()}], {igemm_log2(nb_k1)}") + + # output + self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wo() if self.tunable.nxe != 0 else s.s_wi()}], s[{s.s_out_stride_wo()}]") + self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho() if self.tunable.nxe != 0 else s.s_hi()}], s[{s.s_tmp(1)}]") + + # TODO: accumulate splited batch here # early init s_knum in case shifted self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_wei_stride_k()}]") # pad gemm_m, gemm_n if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_ho()}], s[{s.s_wo()}]") + self._emit(f"s_mul_i32 s[{s.s_dim_b()}], s[{s.s_ho()}], s[{s.s_wo()}]") else: - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_hi()}], s[{s.s_wi()}]") + self._emit(f"s_mul_i32 s[{s.s_dim_b()}], s[{s.s_hi()}], s[{s.s_wi()}]") - self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_n()}], s[{s.s_tmp(1)}]") + 
self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_n()}], s[{s.s_dim_b()}]") self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_m_per_block - 1}, s[{s.s_tmp(2)}]") self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") self._emit(f"s_lshl_b32 s[{s.s_dim_m()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_m_per_block)}") self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_n_per_block - 1}, s[{s.s_k()}]") - self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") - self._emit(f"s_lshl_b32 s[{s.s_dim_n()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") - - - - - # warp around the really dim_b length, in case pad - # if self.tunable.nxe != 0: - # self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_ho()}], s[{s.s_wo()}]") - # self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.nxb - 1}, s[{s.s_tmp(4)}]") - # self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.nxb)}") - # self._emit(f"s_lshl_b32 s[{s.s_dim_b()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.nxb)}") - # else: - # self._emit(f"s_mul_i32 s[{s.s_dim_b()}], s[{s.s_hi()}], s[{s.s_wi()}]") # no pad - - # for gemm_m pad - # self._emit_empty_line() - # self._emit(f"; pad k if need") - # self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_m_per_block - 1}, s[{s.s_k()}]") - # self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") - # self._emit(f"s_lshl_b32 s[{s.s_k_padded()}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_n()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_n_per_block)}") self._emit_empty_line() self._emit(f"; gemm_m_per_block:{self.tunable.gemm_m_per_block}, gemm_n_per_block:{self.tunable.gemm_n_per_block}, source_access_order:{self.tunable.source_access_order}") - # calculate group index TODO: use blockIdx.y as group index + # calculate group index self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_dim_m()}], {igemm_log2(self.tunable.gemm_m_per_block)}") - self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_k()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_dim_n()}], {igemm_log2(self.tunable.gemm_n_per_block)}") self._emit(f"s_mul_i32 s[0], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080010 ; offset:16, width:8") @@ -1188,121 +1160,44 @@ def emit_kernel_prologue(self): self._emit(f"; s_tmp+4:block_gtc_in, s_tmp+5:block_gtc_im") self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ik()}], s[{s.s_tmp(4)}], {igemm_log2(self.tunable.gemm_n_per_block)}") - self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(5)}], {igemm_log2(self.tunable.gemm_m_per_block)}") - - if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[0], s[{s.s_ho()}], s[{s.s_wo()}]") - else: - self._emit(f"s_mul_i32 s[0], s[{s.s_hi()}], s[{s.s_wi()}]") - - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - assert False - else: - self._emit(m_int_div_rem_ss(s.s_tmp(4), '1', s.s_tmp(5), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) - self._emit(f"s_cmp_lt_u32 s[1], s[{s.s_n()}]") # always compare last dim, here is n - self._emit(f"s_cbranch_scc0 {self.label_out}") # jump to end, basically this would not happen, but happen within block - - if self.tunable.nxe != 0: - # add n - self_emit(f"s_lshl_b32 
s[{s.s_tmp(2)}], s[1], {igemm_log2(data_byte)}") - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_tmp(2)}], s[{s.s_in_stride_n()}]") - self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_tmp(2)}], s[{s.s_in_stride_n()}]") - self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") - - self._emit(m_int_div_rem_ss('0', '1', s.s_tmp(4), s.s_wo(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) - # s0:s_iwo, s1:s_iho, ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h, - self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s1, s[{s.s_stride_h()}]") - self._emit(f"s_sub_i32 s[{s.s_in_ihi()}], s[{s.s_tmp(0)}], s[{s.s_pad_h()}]") - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s0, s[{s.s_stride_w()}]") - self._emit(f"s_sub_i32 s[{s.s_in_iwi()}], s[{s.s_tmp(1)}], s[{s.s_pad_w()}]") - - - self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_wi()}], s[{s.s_ihi()}]") - self._emit(f"s_add_u32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], s[{s.s_iwi()}]") - self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_in_stride_wi()}], {igemm_log2(data_byte)}") - - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_tmp(4)}], s[{s.s_tmp(5)}]") - self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_tmp(4)}], s[{s.s_tmp(5)}]") - self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") - - self._emit(f"s_mul_i32 s[{s.s_tmp(5)}], s[{s.s_ho()}], s[{s.s_wo()}]") - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - assert False - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in, v.v_gtc_ta_inb1(), s.s_tmp(5), v.v_tmp(), s.s_tmp())) - self._emit(f"v_mul_lo_u32 v[{v.v_in_os()}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") # not shifted yet - # v_tmp5, n, v_tmp4, ho*wo - self._emit(m_int_div_rem_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) - - #self._emit(f"v_add_u32 v[{v.v_in_iho()}], s1, v[{v.v_in_iho()}]") - #self._emit(f"v_add_u32 v[{v.v_in_iwo()}], s0, v[{v.v_in_iwo()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(0)}], s[{s.s_stride_h()}], v[{v.v_in_iho()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_stride_w()}], v[{v.v_in_iwo()}]") - self._emit(f"v_sub_i32 v[{v.v_in_ihi()}], v[{v.v_tmp(0)}], s[{s.s_pad_h()}]") - self._emit(f"v_sub_i32 v[{v.v_in_iwi()}], v[{v.v_tmp(1)}], s[{s.s_pad_w()}]") - + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_inb()}], s[{s.s_tmp(5)}], {igemm_log2(self.tunable.gemm_m_per_block)}") - else: - # 1x1, can treat nhw as a single dimension - self._emit(f"s_mul_i32 s[{s.s_tmp()}], {data_byte}, s[{s.s_tmp(5)}]") - self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], {data_byte}, s[{s.s_tmp(5)}]") - self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") - - - - - # if unmerge_sub_n1 == 1: - # self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1)} ; total number of nb1") - # else: - # assert na_nb1 >= unmerge_sub_n1 - # if unmerge_sub_n1 == na_nb1: - # self._emit(f"s_mov_b32 s[0], s[{s.s_dim_b()}] ; total number of nb1") - # else: - # self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_b()}], {igemm_log2(na_nb1 // unmerge_sub_n1)} ; total number of nb1") - - - # if IGEMM_GTC_FEAT_MAGIC_DIVISION: - # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") - # self._emit(m_mdiv_u32_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), s.s_magic_1(), s.s_tmp(3), '0', s.s_tmp())) - # else: - # 
self._emit(m_int_div_rem_ss(s.s_block_gtc_in1b(), s.s_block_gtc_in0(), s.s_tmp(5), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) - # if na_nb1 != 1: - # self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in1b()}], s[{s.s_block_gtc_in1b()}], {igemm_log2(na_nb1)}") - # if na_nb0 != 1: - # self._emit(f"s_lshl_b32 s[{s.s_block_gtc_in0()}], s[{s.s_block_gtc_in0()}], {igemm_log2(na_nb0)}") - # self._emit_empty_line() - - self._emit(f"; in nb1 transform") - if ca_nb1 == 1: - self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}]") - else: - self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_gtc_ta_inb1()}]") + # transform nb + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_inb()}], v[{v.v_in_inb()}]") if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") self._emit(m_mdiv_u32_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w self._emit(f"v_mul_lo_u32 v[{v.v_in_iho()}], s[{s.s_stride_h()}], v[{v.v_in_iho()}]") - self._emit(f"v_sub_i32 v[{v.v_in_iho()}], v[{v.v_in_iho()}], s[{s.s_pad_h()}]") + self._emit(f"v_sub_i32 v[{v.v_in_ihi()}], v[{v.v_in_iho()}], s[{s.s_pad_h()}]") self._emit(f"v_mul_lo_u32 v[{v.v_in_iwo()}], s[{s.s_stride_w()}], v[{v.v_in_iwo()}]") - self._emit(f"v_sub_i32 v[{v.v_in_iwo()}], v[{v.v_in_iwo()}], s[{s.s_pad_w()}]") - self._emit(m_in_update_hw(v.v_in_ihi(), v.v_in_iwi(), v.v_in_iho(), v.v_in_iwo(), v.v_in_iy(), v.v_in_ix(), s.s_dilation_h(), s.s_dilation_w())) + self._emit(f"v_sub_i32 v[{v.v_in_iwi()}], v[{v.v_in_iwo()}], s[{s.s_pad_w()}]") self._emit_empty_line() + else: if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") self._emit(m_mdiv_u32_vs(v.v_in_iwi(), v.v_in_ihi(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_gtc_ta_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_iwi(), v.v_in_ihi(), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) + ''' + from here, need track ihi, iwi in move slice window + ''' + + # update flag for batch size + self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_n()}], v[{self.v_in_in()}]") + self._emit(f"v_cndmask_b32 
v[{self.v_flag_n()}], 0, 1, vcc") self._emit(f"; calculate in offset") # compute group distance @@ -1311,52 +1206,84 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") - - self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_in0()}], {igemm_log2(unmerge_sub_n1 * data_byte)}") - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_in_stride_n()}], s[{s.s_tmp(3)}]") - self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_in_stride_n()}], s[{s.s_tmp(3)}]") - self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") - self._emit_empty_line() - #if gemm_m_unmerge_cluster == 0: - if ca_nb0 != 1: - self._emit(tc_index_accumulator(v.v_tmp(1), v.v_gtc_ta_in0(), v.v_gtc_ta_in1(), ca_nb0, ca_nb1, 0, unmerge_sub_n1)) - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_tmp(1)}]") - else: - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_gtc_ta_in1()}]") - # else: - # # no in0 - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_gtc_ta_in1()}]") - + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") # s_in_stride_wi need shift before! self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) + + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi()}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi()}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") if self.tunable.nxe != 0: - self._emit(f"v_add_lshl_u32 v[{v.v_in_os_base()}], v[{v.v_gtc_ta_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi()}]") - self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi()}], v[{v.v_tmp()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") - self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_in_os_base()}], v[{v.v_tmp()}]") - - self._emit(m_set_flag_hw(v.v_in_flag(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) - else: - self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ta_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi()}]") - self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi()}], v[{v.v_tmp()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") - self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") + self._emit(m_set_flag_nhw(v.v_in_flag(), v.v_flag_n(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) self._emit_empty_line() - if self.in_thread_copy_ndim != 1: - if s_in_stride_d0 != s_dummy: - self._emit(self.try_shift_stride(s_in_stride_d0, igemm_log2(data_byte))) - if s_in_stride_d1 != s_dummy: - self._emit(self.try_shift_stride(s_in_stride_d1, igemm_log2(data_byte))) - self._emit_empty_line() + if self.tunable.nxe != 0: + self._emit(f"s_mul_i32 s[{s.s_len_h()}], s[{s.s_ho()}], s[{s.s_stride_h()}]") + self._emit(f"s_mul_i32 s[{s.s_len_w()}], s[{s.s_wo()}], 
s[{s.s_stride_w()}]") + self._emit(f"s_mov_b32 s[{s.s_lim_h()}], s[{s.s_len_h()}]") + self._emit(f"s_mov_b32 s[{s.s_lim_w()}], s[{s.s_len_w()}]") + + # voffset + if ta_nb0 != 1 or ta_nb1 != 1: + thread_stride = na_nb1 if ta_nb0 != 1 else 1 + self._emit(f"s_mov_b32 s[{s.s_tmp(5)}], {thread_stride}") + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_thread_stride_n(), s.s_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), s.s_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_ss(s.s_thread_stride_w(), s.s_thread_stride_h(), s.s_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), s.s_tmp())) + else: + self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_thread_stride_n(), s.s_tmp(5), s.s_dim_b(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_ss(s.s_thread_stride_w(), s.s_thread_stride_h(), s.s_tmp(4), s.s_wo(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) - if self.tunable.precache_soffset: - self._emit(m_in_2d_global_load.init_precache_soffset(s_in_stride_d0(), s_in_stride_d1(), s.s_in_offset(), s.s_tmp())) + if self.tunable.nxe != 0: + self._emit(f"s_mul_i32 s[{s.s_thread_stride_h()}], s[{s.s_thread_stride_h()}], s[{s.s_stride_h()}]") + self._emit(f"s_mul_i32 s[{s.s_thread_stride_w()}], s[{s.s_thread_stride_w()}], s[{s.s_stride_w()}]") + + # now let's precompute all the voffset + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v[{v.v_in_ihi()}]") + self._emit(f"v_mov_b32 v[{v.v_tmp(3)}], v[{v.v_in_in()}]") + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + for i in range(1, nb_per_thread): + # v_tmp+4:ihi, v_tmp+5:iwi + self._emit(f"v_add_i32 v[{v.v_tmp(4)}], s[{s.s_thread_stride_w()}], v[{v.v_in_iwi() if i == 1 else v.v_tmp(4) }]") + self._emit(f"v_cmpx_le_i32 vcc, s[{s.s_lim_w()}], v[{v.v_tmp(4)}]") + self._emit(f"v_subrev_i32 v[{v.v_tmp(4)}], s[{s.s_len_w()}], v[{v.v_tmp(4)}]") + if self.tunable.nxe != 0: + self._emit(f"v_add_i32 v[{v.v_tmp(5)}], s[{s.s_stride_h()}], v[{v.v_tmp(5)}]") + else: + self._emit(f"v_add_i32 v[{v.v_tmp(5)}], 1, v[{v.v_tmp(5)}]") + self._emit(f"s_mov_b64 exec, -1") + + self._emit(f"v_add_i32 v[{v.v_tmp(5)}], s[{s.s_thread_stride_h()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cmpx_le_i32 vcc, s[{s.s_lim_h()}], v[{v.v_tmp(5)}]") + self._emit(f"v_subrev_i32 v[{v.v_tmp(5)}], s[{s.s_len_h()}], v[{v.v_tmp(5)}]") + self._emit(f"v_add_u32 v[{v.v_tmp(3)}], 1, v[{v..v_tmp(3)}]") + self._emit(f"s_mov_b64 exec, -1") + + self._emit(f"v_add_u32 v[{v.v_tmp(3)}], s[{s.s_thread_stride_n()}], v[{v.v_tmp(3)}]") + + if self.tunable.nxe != 0: + # update flag for batch size + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_tmp(3)}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(m_set_flag_nhw(v.v_flag(i), v.v_tmp(), v.v_tmp(5), v.v_tmp(4), s.s_hi(), s.s_wi())) + + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v.v_tmp(3)}]") + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(2)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_tmp(5)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os(i)}], v[{v.v_tmp(2)}], 
v[{v.v_tmp()}]") + + else: + pass # load in self._emit(self.global_load_in()) @@ -1376,17 +1303,9 @@ def emit_kernel_prologue(self): self._emit(f"s_add_u32 s[{s.s_p_wei()}], s[{s.s_p_wei()}], s[{s.s_tmp()}]") self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") - self._emit(f"v_add_u32 v[{v.v_cur_k()}], s[{s.s_block_gtc_ik()}], v[{v.v_gtc_tb_ik()}]") + self._emit(f"v_add_u32 v[{v.v_cur_k()}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_cur_k()}]") - - self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(data_byte)}") - - # self._emit(m_wei_update_os(v.v_wei_os(), v.v_wei_os_base(), v.v_wei_iy(), v.v_wei_ix(), s.s_x(), v.v_tmp())) - #else: - # self._emit(tc_index_accumulator(v.v_tmp(), v.v_gtc_ik0(), v.v_gtc_ta_ik1(), ca_k0, ca_k1, na_k0, na_k1)) - # self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_ik()}], v[{v.v_tmp()}]") - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(5)}]") - # self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic1e()}], {igemm_log2(data_byte)}") + self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(data_byte)}") self._emit_empty_line() if self.wei_thread_copy_ndim != 1: @@ -1411,21 +1330,24 @@ def emit_kernel_prologue(self): self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") self._emit(self.xdlops_mapping.get_gemm_index_for_dst_matrix(v.v_co_sst(), v.v_co_sld(), v.v_tmp(5), v.v_tmp())) + ''' + gemm_k * gemm_m * k_pack + ''' self._emit(f"; LDS store, in: e,c,nb0,nb1: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}") if ca_nb1 == 1: # TODO: remove this path, not possible go here assert False else: if ca_nb0 == 1: - self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_inb1()}]") + self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_in_inb()}]") else: - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_nb1)}, v[{v.v_gtc_ta_inb1()}]") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(na_nb0*na_nb1)}, v[{v.v_tmp()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_nb1)}, v[{v.v_in_inb()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(na_nb0*na_nb1)}, v[{v.v_tmp()}]") self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") self._emit_empty_line() self._emit(f"; LDS store, wei: e,c,k: {ta_e}x{ta_c}x{tb_k}, {ca_e}x{ca_c}x{cb_k}") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_ic()}], {igemm_log2(nb_k)}, v[{v.v_gtc_tb_ik()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(nb_k)}, v[{v.v_wei_ik()}]") self._emit(f"v_lshlrev_b32 v[{v.v_sst_b_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") self._emit(f"v_add_u32 v[{v.v_sst_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_b_os()}]") self._emit_empty_line() diff --git a/igemm/algo/shared_memory.py b/igemm/algo/shared_memory.py index 58d7849b..0d95c97b 100644 --- a/igemm/algo/shared_memory.py +++ b/igemm/algo/shared_memory.py @@ -866,8 +866,8 @@ def serialize(self): class macro_igemm_3d_shared_store_t(macro_base_t): ''' this is indeed for - 0: gemm_k * gemm_m/n * k_pack - 1: gemm_m/n * gemm_k * k_pack + 0: gemm_k * gemm_m/n * k_pack, src_order = 0 + 1: gemm_m/n * gemm_k * k_pack, src_order = 1 (unsupported) we always want to use k_pack as vector store ''' def __init__(self, mc, ctrl, 
inline = False): @@ -913,8 +913,9 @@ def expr(self): self._emit(ds_write2(f'{self.v_sst_os()}', f'{self.v_src()}+{2 * i_d*ctrl.length_dp}', 2 * i_d * stride_d)) issue_cnt += ds_write2.get_issues(2 * i_d * stride_d) else: + # nhwc almost all case goes here for i_d in range(length_d): - self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}+{s_id}', i_d0 * ctrl.stride_d0)) + self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}+{i_d*ctrl.length_dp}', i_d * ctrl.stride_d)) issue_cnt += ds_write.get_issues() else: assert False, "un implemented yet" From a782aaa78372df83517ffc5bb7d9a0e04bdcc2a1 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 24 Jan 2021 23:11:00 +0800 Subject: [PATCH 09/40] support kpack in main loop --- igemm/algo/mfma_main_loop.py | 355 +++++++++++++++++++---------------- 1 file changed, 193 insertions(+), 162 deletions(-) diff --git a/igemm/algo/mfma_main_loop.py b/igemm/algo/mfma_main_loop.py index 13f7d61c..9dbb2211 100644 --- a/igemm/algo/mfma_main_loop.py +++ b/igemm/algo/mfma_main_loop.py @@ -67,6 +67,11 @@ def __init__(self): self.s_kitr = None self.s_knum = None + # below is in unit of pixel, not considered data_type bytes + self.lds_k_pack = 1 + self.lds_pad_m = 0 # pad how many pixels per m row + self.lds_pad_n = 0 # pad how many pixels per n row + class mfma_main_loop_t(mc_base_t): ''' ''' @@ -119,6 +124,22 @@ def emit(self): unroll_k = self.ctrl.unroll_k k_per_inst = cxm.block_k() + pad_m = self.ctrl.lds_pad_m + pad_n = self.ctrl.lds_pad_n + + def mapped_ioffset(i_k, width_byte, pad_pixel, offset = 0): + k_pack = self.ctrl.lds_k_pack + i_k0 = i_k // k_pack + i_kp = i_k % k_pack + return i_k0 * (width_byte * k_pack + pad_pixel * data_byte) + i_kp * k_pack * data_byte + offset + + # mi = mapped_ioffset + def mi_m(i_k, offset = 0): + return mapped_ioffset(i_k, lds_width_m, pad_m, offset) + + def mi_n(i_k, offset = 0): + return mapped_ioffset(i_k, lds_width_n, pad_n, offset) + def mfma_step_mxn(i_repeat_m, i_repeat_n, i_local_buffer_m = 0, i_local_buffer_n = 0): local_buffer_m = cxm.inst_mfma.num_v_a * cxm.wave_step_m * cxm.wave_repeat_m local_buffer_n = cxm.inst_mfma.num_v_b * cxm.wave_step_n * cxm.wave_repeat_n @@ -165,21 +186,21 @@ def mfma_loop_repeat_1x1_lp2(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst))) # lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst))) # lds_width_n * k_per_inst def do_unroll_k_1x1_sub(): unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1 for i_k in range(unroll_k_sub): self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst)) - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst)) + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst))) # (2*i_k+2) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst))) # (2*i_k+2) * lds_width_n * k_per_inst self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - 
self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_n * k_per_inst do_unroll_k_1x1_sub() self._emit(f_move_slice_window_b()) @@ -221,8 +242,8 @@ def do_unroll_k_1x1_sub(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst))) # k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst))) # k_per_inst * lds_width_n do_unroll_k_1x1_sub() self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) @@ -243,13 +264,13 @@ def do_interleave_unroll_k_sub(): for i_k in range(unroll_k_sub): self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst))) # (2*i_k+2) * k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst))) # (2*i_k+2) * k_per_inst * lds_width_n self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst))) # (2*i_k+3) * k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst))) # (2*i_k+3) * k_per_inst * lds_width_n return self._get_deferred() def do_interleave_gload_and_move_slice_window(): @@ -304,8 +325,8 @@ def do_interleave_share_store(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst) )) # k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst) )) # k_per_inst * lds_width_n if (unroll_k // k_per_inst) // 2 - 1 != 0: @@ -338,8 +359,8 @@ def do_interleave_share_store(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + k_per_inst * lds_width_m)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + k_per_inst * lds_width_n)) + self._emit(f_sld_a(v_a(local_buffer_m), 
v_sld_a_os(), lds_base_m + mi_m(k_per_inst))) # k_per_inst * lds_width_m + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst))) # k_per_inst * lds_width_n self._emit(do_interleave_unroll_k_sub()) self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 0, 0)) @@ -374,8 +395,8 @@ def mfma_loop_repeat_2x2_lp2(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 def do_unroll_k_sub(): unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1 @@ -385,65 +406,65 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k == 0 else 5})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_n * k_per_inst else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") # (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt({3 if i_k == 0 else 5})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == 0: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst + lds_width_n // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst, lds_width_n // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_n * k_per_inst + lds_width_n // 2 + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst, lds_width_m // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_m * k_per_inst + lds_width_m // 2 else: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), 
lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma self._emit(f's_waitcnt lgkmcnt({4 if i_k == 0 else 5})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_n * k_per_inst self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 0, 0)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2) + \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2)) + \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2 self._emit_empty_line() self._emit(f"; k iteration : {2 * i_k + 1}") # 1st fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2) )+ \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 1, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst) ) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") # (2*i_k+3) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(1, 0, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") # (2*i_k+3) * lds_width_n * k_per_inst self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n//2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") # (2*i_k+3) * 
lds_width_n * k_per_inst + lds_width_n//2 if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (unroll_k // k_per_inst - 1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((unroll_k // k_per_inst - 1) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") # (unroll_k // k_per_inst - 1) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() do_unroll_k_sub() @@ -525,8 +546,8 @@ def do_unroll_k_sub(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 do_unroll_k_sub() self._emit(f"; k iteration : {unroll_k - 2}") # 1st fma @@ -585,52 +606,52 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k == 0 else 5})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_n * k_per_inst else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") # (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt({3 if i_k == 0 else 5})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == 0: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst + lds_width_n // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst ,lds_width_n // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_n * k_per_inst + lds_width_n // 2 + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst , lds_width_m 
// 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_m * k_per_inst + lds_width_m // 2 else: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma self._emit(f's_waitcnt lgkmcnt({4 if i_k == 0 else 5})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") # (2*i_k+2) * lds_width_n * k_per_inst self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 0, 0)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2) + \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2)) + \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2 self._emit_empty_line() self._emit(f"; k iteration : {2 * i_k + 1}") # 1st fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ \ - f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ \ + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") # (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2 self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(0, 1, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") # (2*i_k+3) * lds_width_m * k_per_inst self._emit_empty_line() # 3rd fma @@ -641,7 +662,7 @@ def do_interleave_unroll_k_sub(): # 4th fma self._emit(mfma_step_mxn(1, 1, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n//2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") # (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2 if i_k == unroll_k_sub - 1: self._emit(f_sld_a(v_a(local_buffer_m 
+ repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (unroll_k // k_per_inst - 1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") self._emit_empty_line() @@ -723,8 +744,8 @@ def do_interleave_share_store(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 if (unroll_k // k_per_inst) // 2 - 1 != 0: mbb_list_sub = [create_machine_basic_block(do_interleave_unroll_k_sub(), group_mbb_by_end_of_inst_op="v_mfma"), @@ -763,8 +784,8 @@ def do_interleave_share_store(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 self._emit(do_interleave_unroll_k_sub()) self._emit(f"; k iteration : {unroll_k - 2}") @@ -831,12 +852,13 @@ def mfma_loop_repeat_2x2(): self._emit(f"; do fma accumulate with unroll {unroll_k // k_per_inst}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2) )) # lds_width_n // 2 + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 - self._emit(f".itr_k = 0") - self._emit(f".rept {unroll_k // k_per_inst - 1}") - with self._indent_context(): + # self._emit(f".itr_k = 0") + # self._emit(f".rept {unroll_k // k_per_inst - 1}") + #with self._indent_context(): + for i_k in range( unroll_k // k_per_inst - 1): # 1st fma self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0)) @@ -846,21 +868,25 @@ def mfma_loop_repeat_2x2(): self._emit(mfma_step_mxn(0, 1)) # 3rd fma - self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}')) + # self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}')) + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((i_k+1) * k_per_inst))) self._emit(f's_waitcnt lgkmcnt(1)') self._emit(mfma_step_mxn(1, 0)) # 4th fma - self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}')) + # self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}')) + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((i_k+1)* k_per_inst))) self._emit(mfma_step_mxn(1, 1)) 
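# Note on the loop rewrite above: the old code unrolled this at assembly time with
# .rept/.itr_k, so every LDS offset had to stay a linear expression in the assembler
# symbol .itr_k. Once lds_k_pack and row padding are folded in, the offset is no
# longer a single linear term, so the unrolling moves to generator time and each
# offset becomes a plain integer computed through mi_m()/mi_n(). Roughly (a sketch,
# not the emitted code; off_a/off_b are illustrative names only):
#
#     for i_k in range(unroll_k // k_per_inst - 1):
#         off_a = lds_base_m + mi_m((i_k + 1) * k_per_inst)   # was (.itr_k+1)*lds_width_m*k_per_inst
#         off_b = lds_base_n + mi_n((i_k + 1) * k_per_inst)   # was (.itr_k+1)*lds_width_n*k_per_inst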
self._emit_empty_line() # last - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}')) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}')) - self._emit('.itr_k = .itr_k + 1') + # self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}')) + # self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}')) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((i_k+1) * k_per_inst, lds_width_n//2))) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((i_k+1) * k_per_inst, lds_width_m//2))) + # self._emit('.itr_k = .itr_k + 1') - self._emit(f".endr") + # self._emit(f".endr") self._emit_empty_line() self._emit(f"; last unroll") self._emit(f"v_xor_b32 v[{v_sld_b_os()}], {lds_single_size}, v[{v_sld_b_os()}] ; switch double buffer b load") @@ -923,9 +949,10 @@ def mfma_loop_repeat_2x2(): self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) - self._emit(f".itr_k = 0") - self._emit(f".rept {unroll_k // k_per_inst - 1}") - with self._indent_context(): + # self._emit(f".itr_k = 0") + # self._emit(f".rept {unroll_k // k_per_inst - 1}") + # with self._indent_context(): + for i_k in range(unroll_k // k_per_inst - 1): # 1st fma self._emit('s_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0)) @@ -935,18 +962,22 @@ def mfma_loop_repeat_2x2(): self._emit(mfma_step_mxn(0, 1)) # 3rd fma - self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}')) + # self._emit(f_sld_a(v_a(), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}')) + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((i_k+1)* k_per_inst))) self._emit('s_waitcnt lgkmcnt(1)') self._emit(mfma_step_mxn(1, 0)) # 4th fma - self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}')) + # self._emit(f_sld_b(v_b(), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}')) + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((i_k+1)* k_per_inst))) self._emit(mfma_step_mxn(1, 1)) self._emit_empty_line() # last - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}')) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}')) + #self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), f'{lds_base_n}+(.itr_k+1)*{lds_width_n * k_per_inst}+{lds_width_n//2}')) + #self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), f'{lds_base_m}+(.itr_k+1)*{lds_width_m * k_per_inst}+{lds_width_m//2}')) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((i_k+1) * k_per_inst, lds_width_n//2))) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((i_k+1) * k_per_inst, lds_width_m//2))) self._emit('.itr_k = .itr_k + 1') self._emit('.endr') self._emit_empty_line() @@ -991,7 +1022,7 @@ def mfma_loop_repeat_2x1_lp2(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) 
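# For reference, the mi_m()/mi_n() helpers used throughout these hunks wrap
# mapped_ioffset(), which folds the k_pack sub-index and the per-row pad into the
# LDS byte offset. A standalone restatement of that mapping (argument names here
# are illustrative, mirroring the helper defined earlier in this file):
#
#     def lds_offset(i_k, width_byte, pad_pixel, k_pack, data_byte, extra=0):
#         i_k0, i_kp = divmod(i_k, k_pack)
#         return i_k0 * (width_byte * k_pack + pad_pixel * data_byte) \
#                + i_kp * k_pack * data_byte + extra
#
# With lds_k_pack == 1 and no padding this collapses to i_k * width_byte + extra,
# i.e. exactly the "i_k * lds_width_m/n" offsets that the removed lines computed.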
self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2) )) # lds_width_m // 2 def do_unroll_k_sub(): unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1 @@ -1001,33 +1032,33 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_n * k_per_inst + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {0}") # lds_width_m * k_per_inst if unroll_k_sub == 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst, lds_width_m // 2) ) + \ + f" ; load i_k:{1} into local buffer {1}, repeat {1}") # lds_width_m * k_per_inst + lds_width_m // 2 elif i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ - f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ - f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") # (2*i_k+1) * lds_width_m * k_per_inst + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 )) + \ + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") # (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 else: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ - f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ + f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") # (2*i_k+1) * lds_width_m * k_per_inst self._emit_empty_line() # 2nd fma self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), 
v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1035,11 +1066,11 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 0, 1, 1)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") else: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() # 2nd fma @@ -1047,11 +1078,11 @@ def do_unroll_k_sub(): self._emit(mfma_step_mxn(1, 0, 1, 1)) if i_k == unroll_k_sub - 1: # v_b attension! 
- self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") self._emit_empty_line() do_unroll_k_sub() @@ -1106,7 +1137,7 @@ def do_unroll_k_sub(): self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2 ))) do_unroll_k_sub() self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}") # 1st fma @@ -1146,20 +1177,20 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") if unroll_k_sub == 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(k_per_inst,lds_width_m // 2 )) + \ f" ; load i_k:{1} into local buffer {1}, repeat {1}") elif i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 
)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") self._emit_empty_line() @@ -1167,12 +1198,12 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(1, 0, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+1) * lds_width_m * k_per_inst + lds_width_m // 2 ) + \ + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+1) * k_per_inst, lds_width_m // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1180,11 +1211,11 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 0, 1, 1)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") else: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() # 2nd fma @@ -1192,11 +1223,11 @@ def do_interleave_unroll_k_sub(): self._emit(mfma_step_mxn(1, 0, 1, 1)) if i_k == unroll_k_sub - 1: # v_b attension! 
- self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst, lds_width_m // 2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst + lds_width_m // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst, lds_width_m // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") self._emit_empty_line() return self._get_deferred() @@ -1263,7 +1294,7 @@ def do_interleave_share_store(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2 ))) if (unroll_k // k_per_inst) // 2 - 1 != 0: mbb_list_sub = [create_machine_basic_block(do_interleave_unroll_k_sub(), group_mbb_by_end_of_inst_op="v_mfma"), @@ -1298,7 +1329,7 @@ def do_interleave_share_store(): self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2 ))) self._emit(do_interleave_unroll_k_sub()) self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}") # 1st fma @@ -1346,7 +1377,7 @@ def mfma_loop_repeat_1x2_lp2(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 ))) def do_unroll_k_sub(): unroll_k_sub = (unroll_k // k_per_inst) // 2 - 1 @@ -1356,20 +1387,20 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) 
+ \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") if unroll_k_sub == 1: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{1} into local buffer {1}, repeat {1}") elif i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") self._emit_empty_line() @@ -1377,12 +1408,12 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1390,11 +1421,11 @@ def do_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 0, 1, 1)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer 
{1}, repeat {0}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") else: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() # 2nd fma @@ -1402,11 +1433,11 @@ def do_unroll_k_sub(): self._emit(mfma_step_mxn(0, 1, 1, 1)) if i_k == unroll_k_sub - 1: # v_b attension! - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n // 2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") self._emit_empty_line() do_unroll_k_sub() @@ -1461,7 +1492,7 @@ def do_unroll_k_sub(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 ))) do_unroll_k_sub() self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}") # 1st fma @@ -1500,20 +1531,20 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({1 if i_k == 0 else 2})') self._emit(mfma_step_mxn(0, 0, 0, 0)) if i_k == 0: - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + lds_width_m * k_per_inst) + \ + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") if unroll_k_sub == 1: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n 
+ lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_m(k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{1} into local buffer {1}, repeat {1}") elif i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst) + \ + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {0}") self._emit_empty_line() @@ -1521,12 +1552,12 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 1, 0, 0)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") else: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+1) * lds_width_n * k_per_inst + lds_width_n // 2 ) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{2*i_k+1} into local buffer {1}, repeat {1}") - self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + (2*i_k+2) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m + mi_m((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() self._emit(f"; k iteration : {(2 * i_k + 1) * k_per_inst}") @@ -1534,11 +1565,11 @@ def do_interleave_unroll_k_sub(): self._emit(f's_waitcnt lgkmcnt({2 if i_k != unroll_k_sub - 1 else 3})') self._emit(mfma_step_mxn(0, 0, 1, 1)) if i_k == unroll_k_sub - 1: - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + 
self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{(2*i_k+3)} into local buffer {1}, repeat {0}") else: - self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") + self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst)) + f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {0}") self._emit_empty_line() # 2nd fma @@ -1546,11 +1577,11 @@ def do_interleave_unroll_k_sub(): self._emit(mfma_step_mxn(0, 1, 1, 1)) if i_k == unroll_k_sub - 1: # v_b attension! - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n // 2) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") else: - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + (2*i_k+2) * lds_width_n * k_per_inst + lds_width_n // 2)+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + (2*i_k+3) * lds_width_m * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+2) * k_per_inst, lds_width_n // 2))+ f" ; load i_k:{2*i_k+2} into local buffer {0}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") self._emit_empty_line() return self._get_deferred() @@ -1615,7 +1646,7 @@ def do_interleave_share_store(): self._emit(f"; do fma accumulate with unroll {unroll_k}") self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 ))) if (unroll_k // k_per_inst) // 2 - 1 != 0: mbb_list_sub = [create_machine_basic_block(do_interleave_unroll_k_sub(), group_mbb_by_end_of_inst_op="v_mfma"), @@ -1650,7 +1681,7 @@ def do_interleave_share_store(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 ))) self._emit(do_interleave_unroll_k_sub()) self._emit(f"; k iteration : {unroll_k - 2 * k_per_inst}") # 1st fma From c9c97cb5eecb7118320e9098706f58b909fcf748 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 24 Jan 2021 23:19:02 +0800 Subject: [PATCH 10/40] tiny fix kpack --- igemm/algo/mfma_main_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/igemm/algo/mfma_main_loop.py b/igemm/algo/mfma_main_loop.py index 
9dbb2211..85e96527 100644 --- a/igemm/algo/mfma_main_loop.py +++ b/igemm/algo/mfma_main_loop.py @@ -1536,7 +1536,7 @@ def do_interleave_unroll_k_sub(): self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n(k_per_inst)) + \ f" ; load i_k:{1} into local buffer {1}, repeat {0}") if unroll_k_sub == 1: - self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_m(k_per_inst, lds_width_n // 2 )) + \ + self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(k_per_inst, lds_width_n // 2 )) + \ f" ; load i_k:{1} into local buffer {1}, repeat {1}") elif i_k == unroll_k_sub - 1: self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+1) * k_per_inst)) + \ From ad0b413cf172539258ac0308b5271559fd14f702 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 28 Jan 2021 16:09:43 +0800 Subject: [PATCH 11/40] nxe=0 now works! --- config/igemm_fwd_gtc_gfx908_nhwc.config | 23 + igemm/algo/global_memory.py | 14 +- igemm/algo/igemm_fwd_gtc_nhwc.py | 906 ++++++++++++------------ igemm/algo/mfma_main_loop.py | 62 +- igemm/algo/shared_memory.py | 5 +- igemm/algo/xdlops_mapping.py | 28 +- igemm/codegen/macro.py | 5 + igemm/codegen/mbb.py | 29 +- 8 files changed, 576 insertions(+), 496 deletions(-) diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index 25b44581..1aa2093d 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -24,6 +24,29 @@ tensor_layout = 'nhwc' nxb = 4 nxe = 0 + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 2 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 4 +nxe = 0 + #--------------------------- 256x128 [igemm_fwd_gtc] gemm_m_per_block = 256 diff --git a/igemm/algo/global_memory.py b/igemm/algo/global_memory.py index d3abca2a..68936eb0 100755 --- a/igemm/algo/global_memory.py +++ b/igemm/algo/global_memory.py @@ -107,6 +107,7 @@ def __init__(self): self.precision = 'fp32' # 'fp32', 'fp16', ... 
self.src_order = 0 # 0-d0xd1, 1-d1xd0 self.dst_order = 0 # 0-d0xd1, 1-d1xd0 + self.bfe_flag = 0 class macro_igemm_2d_global_load_t(macro_base_t): # TODO: if need vectorize further LDS write, need shuffle dst gpr while load @@ -378,8 +379,11 @@ def __init__(self, mc, ctrl, inline = False): self.ctrl = ctrl self.declare_arg("v_dst") self.declare_arg("s_ptr") + self.declare_arg("s_os") self.declare_arg("v_os") self.declare_arg("v_flag") + if self.ctrl.bfe_flag: + self.declare_arg("v_tmp") def name(self): ctrl = self.ctrl @@ -412,10 +416,14 @@ def expr(self): for i_d0 in range(ctrl.length_d0): for i_d1 in range(n_d1): if self.v_flag != None: - self._emit(f"v_cmpx_eq_u32 vcc, 1, v[{self.v_flag(i_cnt)}]") - self._emit(buffer_load_dword(f"{self.v_dst()}+{i_cnt*ctrl.vector_d1}", f"{self.v_os(i_cnt)}", f"{self.s_ptr()}", 0, 0)) + if ctrl.bfe_flag: + self._emit(f"v_bfe_u32 v[{self.v_tmp()}], v[{self.v_flag()}], {i_cnt}, 1") + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_tmp()}]") + else: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_cnt)}]") + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_cnt*ctrl.vector_d1}", f"{self.v_os(i_cnt)}", f"{self.s_ptr()}", f"{self.s_os()}", 0)) if self.v_flag != None: - self._emit(f"s_mov_b32 exec, -1") + self._emit(f"s_mov_b64 exec, -1") i_cnt += 1 def get_issues(self): diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index e5786b09..a2f23998 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -35,6 +35,8 @@ from .coalescing_store import * from .mfma_main_loop import * +IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG = 0 + def _find_non_1_index_in_list(list_object): result_list = list() for idx, item in enumerate(list_object): @@ -67,7 +69,6 @@ def __init__(self, mc, tunable): assert self.in_thread_copy_ndim in (0, 1, 2) assert self.wei_thread_copy_ndim in (0, 1, 2) - self.coalescing_store_groups = igemm_next_pow2(self.tunable.coalescing_store_groups) if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: assert (self.tunable.gemm_m_per_thread * self.tunable.gemm_m_repeat) % self.coalescing_store_groups == 0, \ @@ -123,7 +124,6 @@ def flatten(x): ctrl_coalescing_store_xdlops.adjust_optimal_coalescing_groups() # in m1_m0 order, must adjust self.coalescing_store = igemm_coalescing_store_xdlops_t(mc, ctrl_coalescing_store_xdlops) - self.label_out = f"L_{self.name()}_out" self.dict_shifted_stride = dict() @@ -143,46 +143,6 @@ def try_shift_stride(self, gpr, shifter): self.dict_shifted_stride[gpr.label] = gpr self._emit(f"s_lshl_b32 s[{gpr()}], s[{gpr()}], {shifter}") return self._get_deferred() - - # will not support order, since nhwc fix order is enough - ''' - def get_lds_gemm_m_gemm_n_order(self): - def need_reverse_order(x0, x1): - if x0 != 1 and x1 == 1: - return True - if x0 > x1: - return True - return False - - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() - - gemm_n_order = -1 # gemm_n order is not supported - - gemm_m_order = IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B - if self.tunable.allow_lds_reorder: - if need_reverse_order(ta_nb0, ta_nb1): - gemm_m_order = IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0 - assert False, "maybe not correct" - - return gemm_m_order, gemm_n_order - ''' - - class macro_set_flag_hw(macro_base_t): - def __init__(self, mc, inline = False): - macro_base_t.__init__(self, mc, inline) - self.declare_arg("v_flag") - self.declare_arg("v_ih") - self.declare_arg("v_iw") - self.declare_arg("s_h") - self.declare_arg("s_w") - def name(self): - 
return '.v_fwd_gtc_nhwc_set_flag_hw' - - def expr(self): - self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_h()}], v[{self.v_ih()}]") - self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, 1, vcc") - self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_w()}], v[{self.v_iw()}]") - self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag()}], vcc") class macro_set_flag_nhw(macro_base_t): def __init__(self, mc, inline = False): @@ -202,118 +162,140 @@ def expr(self): self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_w()}], v[{self.v_iw()}]") self._emit(f"v_cndmask_b32 v[{self.v_flag()}], 0, v[{self.v_flag()}], vcc") - class macro_in_update_hw_t(macro_base_t): - def __init__(self, mc, inline = False): - macro_base_t.__init__(self, mc, inline) - self.declare_arg("v_in_ihi") - self.declare_arg("v_in_iwi") - self.declare_arg("v_in_iho") - self.declare_arg("v_in_iwo") - self.declare_arg("v_in_iy") - self.declare_arg("v_in_ix") - self.declare_arg("s_dilation_h") - self.declare_arg("s_dilation_w") - def name(self): - return '.v_fwd_gtc_nhwc_in_update_hw' - - def expr(self): - self._emit(f"; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h, here make sure iho <- iho * s_stride_h - s_pad_h before hand") - self._emit(f"; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w, here make sure iwo <- iwo * s_stride_w - s_pad_w before hand") - self._emit(f"v_mad_i32_i24 v[{self.v_in_ihi()}], s[{self.s_dilation_h()}], v[{self.v_in_iy()}], v[{self.v_in_iho()}]") - self._emit(f"v_mad_i32_i24 v[{self.v_in_iwi()}], s[{self.s_dilation_w()}], v[{self.v_in_ix()}], v[{self.v_in_iwo()}]") + class macro_move_slice_window_block_wise_1x1_t(macro_base_t): + def __init__(self, mc, tunable, inline, **options): + macro_base_t.__init__(self, mc, True) + self.tunable = tunable + self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset + self.declare_arg("v_wei_os") + self.declare_arg("s_move_slice_k_stride_c") # this is indeed gemm_k * data_byte, same for input/weight + self.options = options - class macro_in_update_os_t(macro_base_t): - def __init__(self, mc, data_byte, inline = False): - macro_base_t.__init__(self, mc, inline) - self.data_byte = data_byte - self.declare_arg("v_in_os") - self.declare_arg("v_in_os_base") - self.declare_arg("v_in_ihi") - self.declare_arg("v_in_iwi") - self.declare_arg("s_wi") - self.declare_arg("s_in_stride_wi") - self.declare_arg("v_tmp") def name(self): - return '.v_fwd_gtc_nhwc_in_update_os' + return '.v_fwd_gtc_nhwc_move_slice_window_block_wise_1x1' def expr(self): - self._emit(f"v_mad_u32_u24 v[{self.v_tmp()}], v[{self.v_in_ihi()}], s[{self.s_wi()}], v[{self.v_in_iwi()}]") - self._emit(f"v_mul_lo_u32 v[{self.v_tmp()}], s[{self.s_in_stride_wi()}], v[{self.v_tmp()}]") - self._emit(f"v_add_u32 v[{self.v_in_os()}], v[{self.v_tmp()}], v[{self.v_in_os_base()}]") + self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") + self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") + self._emit_empty_line() - class macro_move_slice_window_k_e1_c_t(macro_base_t): + class macro_move_slice_window_block_wise_t(macro_base_t): ''' nhwc gemm_k = e*c, and thread/cluster length for e is always 1 hence always move along c and accumulate into e this macro is for input and weight together. + block-wise move slice window, means we increase y*x*c using sgpr. + Indeed this is always true, since gemm_k % k_per_block == 0 always true. 
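+        as a sketch, one window move emitted by expr() below is simply
+            s_in_offset += s_move_slice_k_stride_c    (scalar c iterator for input)
+            v_wei_os    += s_move_slice_k_stride_c    (wei treats y*x*c as one dimension)
+        with s_move_slice_k_stride_c being the per-iteration gemm_k step in bytes;
+        wrapping the c iterator back into x/y is signalled via s_flag_need_acc_yx
+        and performed separately by the acc_yx macro.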
+ Beside, we always increase along c dimension, this means y, x, c using sgpr is enough + ''' - def __init__(self, mc, tunable, inline = False): - macro_base_t.__init__(self, mc, inline) + def __init__(self, mc, tunable, inline, **options): + macro_base_t.__init__(self, mc, True) self.tunable = tunable - self.declare_arg("v_move_slice_k_iy") - self.declare_arg("v_move_slice_k_ix") - self.declare_arg("v_move_slice_k_ic") - #self.declare_arg("s_gemm_k_num_y") - self.declare_arg("s_gemm_k_num_x") - self.declare_arg("s_gemm_k_num_c") - self.declare_arg("s_move_slice_k_c") - self.declare_arg("v_in_os") + self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset self.declare_arg("v_wei_os") - - # self.declare_arg("s_in_stride_gemm_k_num_c") - self.declare_arg("s_move_slice_k_in_stride_diff_y") # indeed stride_y - stride_x, always possitive - self.declare_arg("s_move_slice_k_in_stride_diff_x") # indeed stride_x - stride_c, always possitive - self.declare_arg("s_move_slice_k_stride_c") # this is indeed s_move_slice_k_c * data_byte, same for input/weight - - self.declare_arg("v_in_ihi") # need update - self.declare_arg("v_in_iwi") # need update - # self.declare_arg("s_dilation_h") - # self.declare_arg("s_dilation_w") - self.declare_arg("s_in_diff_hi") # s_dilation_h - self.declare_arg("s_in_diff_wi") # s_dilation_w - self.declare_arg("s_in_diff_sub_wi") # total wi needed to be deduced from iwi, when carry-on + self.declare_arg("s_move_slice_k_stride_c") # this is indeed gemm_k * data_byte, same for input/weight + self.declare_arg("s_gemm_k_num_c") # c * data_byte + self.declare_arg("s_flag_need_acc_yx") + self.options = options def name(self): - return '.v_fwd_gtc_nhwc_move_slice_window_k_e1_c' + return '.v_fwd_gtc_nhwc_move_slice_window_block_wise' def expr(self): - self._emit(f"v_add_u32 v[{self.v_move_slice_k_ic()}], s[{self.s_move_slice_k_c()}], v[{self.v_move_slice_k_ic()}]") - self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_in_os()}]") - self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") # weight offset always increase, treat y*x*c as single dimension - self._emit(f"v_cmpx_le_u32 vcc, s[{self.s_gemm_k_num_c()}], v[{self.v_move_slice_k_ic()}]") - self._emit(f"v_subrev_u32 v[{self.v_move_slice_k_ic()}], s[{self.s_gemm_k_num_c()}], v[{self.v_move_slice_k_ic()}]") - self._emit(f"v_add_u32 v[{self.v_move_slice_k_ix()}], 1, v[{self.v_move_slice_k_ix()}]") - self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_in_stride_diff_x()}], v[{self.v_in_os()}]") # merge with above c - self._emit(f"v_add_u32 v[{self.v_in_iwi()}], s[{self.s_in_diff_wi()}], v[{self.v_in_iwi()}]") - self._emit(f"s_mov_b64 exec, -1") - self._emit_empty_line() - self._emit(f"v_cmpx_le_u32 vcc s[{self.s_gemm_k_num_x()}], v[{self.v_move_slice_k_ix()}]") - self._emit(f"v_add_u32 v[{self.v_move_slice_k_iy()}], 1, v[{self.v_move_slice_k_iy()}]") - self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_in_stride_diff_y()}], v[{self.v_in_os()}]") - self._emit(f"v_subrev_u32 v[{self.v_in_iwi()}], s[{self.s_in_diff_sub_wi()}], v[{self.v_in_iwi()}]") - self._emit(f"v_add_u32 v[{self.v_in_ihi()}], s[{self.s_in_diff_hi()}], v[{self.v_in_ihi()}]") - self._emit(f"s_mov_b64 exec, -1") + self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") + self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], 
v[{self.v_wei_os()}]") + self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_offset()}]") + self._emit(f"s_cselect_b32 s[{self.s_flag_need_acc_yx()}], 0, 1") self._emit_empty_line() - # free of last dim check - class macro_move_slice_window_k_nxe0_c_t(macro_base_t): + class macro_move_slice_window_block_wise_acc_yx_t(macro_base_t): ''' - used for nxe=0. only c move is needed + can not inline + prefer to put this before global load wait. And for simplicity, no auto schedule. ''' - def __init__(self, mc, tunable, inline = False): - macro_base_t.__init__(self, mc, inline) + def __init__(self, mc, tunable, inline, **options): + macro_base_t.__init__(self, mc, True) self.tunable = tunable + self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset self.declare_arg("v_in_os") - self.declare_arg("v_wei_os") - self.declare_arg("s_move_slice_k_stride_c") # this is indeed s_move_slice_k_c * data_byte - + self.declare_arg("v_in_ihi_list") + self.declare_arg("v_in_iwi_list") + self.declare_arg("v_in_flag") + if not IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self.declare_arg("v_in_flag_n") + self.declare_arg("s_flag_need_acc_yx") + self.declare_arg("s_move_slice_k_ix") + self.declare_arg("s_x") + self.declare_arg("s_in_diff_hi") # this is s_dilation_h * s_in_stride_hi - (x - 1) * s_dilation_w * s_in_stride_wi, always possitive + self.declare_arg("s_in_diff_wi") # this is s_dilation_w * s_in_stride_wi + self.declare_arg("s_dilation_h") + self.declare_arg("s_dilation_w") + self.declare_arg("s_dilation_w_x") # this is -1* (x - 1) * s_dilation_w + self.declare_arg("s_hi") + self.declare_arg("s_wi") + self.declare_arg("v_tmp") # 2 needed + self.declare_arg("s_tmp") + self.options = options def name(self): - return '.v_fwd_gtc_nhwc_move_slice_window_k_nxe0_c' + return '.v_fwd_gtc_nhwc_move_slice_window_block_wise_acc_yx' def expr(self): - self._emit(f"v_add_u32 v[{self.v_in_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_in_os()}]") - self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") + assert "label_acc_yx" in self.options + label_acc_yx = self.options["label_acc_yx"] + '_{}'.format(self.expr_cnt) + label_acc_yx_end = self.options["label_acc_yx"] + '_end' + '_{}'.format(self.expr_cnt) + label_acc_yx_x_end = self.options["label_acc_yx"] + '_x_end' + '_{}'.format(self.expr_cnt) + + assert "nb_per_thread" in self.options + nb_per_thread = self.options["nb_per_thread"] + + assert 'm_set_flag_nhw' in self.options + m_set_flag_nhw = self.options['m_set_flag_nhw'] + + self._emit(f"s_cmp_eq_u32 1, s[{self.s_flag_need_acc_yx()}]") + self._emit(f"s_cbranch_scc0 {label_acc_yx_end} ; no need do accumulate yx") + self._emit_front(f"{label_acc_yx}:") + self._emit(f"s_mov_b32 s[{self.s_in_offset()}], 0") # reset input offset. 
wei, no care + ''' + ix accumulate, will only accumulate in width, and will never carry on to height + iy accumulate, will only accumulate in height, and will never carry on to batch + this makes life easier + ''' + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + self._emit(f"s_add_u32 s[{self.s_move_slice_k_ix()}], 1, s[{self.s_move_slice_k_ix()}]") + self._emit(f"s_cmp_le_u32 s[{self.s_x()}], s[{self.s_move_slice_k_ix()}]") + + # update iwi + self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_dilation_w()}], s[{self.s_dilation_w_x()}]") + for i in range(nb_per_thread): + self._emit(f"v_add_u32 v[{self.v_in_iwi_list(i)}], s[{self.s_tmp()}], v[{self.v_in_iwi_list(i)}]") + + # update in_os + self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_in_diff_wi()}], s[{self.s_in_diff_hi()}]") + for i in range(nb_per_thread): + self._emit(f"v_add_u32 v[{self.v_in_os(i)}], s[{self.s_tmp()}], v[{self.v_in_os(i)}]") + + # update ihi, accumulate + self._emit(f"s_cbranch_scc0 {label_acc_yx_x_end}") + self._emit(f"s_mov_b32 s[{self.s_move_slice_k_ix()}], 0") + for i in range(nb_per_thread): + self._emit(f"v_add_i32 v[{self.v_in_ihi_list(i)}], s[{self.s_dilation_h()}], v[{self.v_in_ihi_list(i)}]") + self._emit_front(f"{label_acc_yx_x_end}:") + + # now set flags + for i in range(nb_per_thread): + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self._emit(f"v_bfe_u32 v[{self.v_tmp(1)}], v[{self.v_in_flag()}], {16 + i}, 1 ; extract flag_n") + self._emit(f"v_and_b32 v[{self.v_in_flag()}], {0xffffffff ^ (1<1 in e dimension" # it's no point to have both x0, x1 have copy value - assert ta_nb0 != 1 and ta_nb1 != 1 - assert tb_k0 != 1 and tb_k1 != 1 + assert not (ta_nb0 != 1 and ta_nb1 != 1) + assert not (tb_k0 != 1 and tb_k1 != 1) return ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 # M, K, N @@ -790,6 +736,10 @@ def get_macro_global_load(self): else: assert False + if self.tunable.nxe != 0: + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + ctrl_in_gld.bfe_flag = 1 + if self.tunable.precache_soffset: return macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_wei_gld, inline), \ macro_igemm_2d_global_load_precache_voffset_t(self.mc, ctrl_in_gld, inline) @@ -824,28 +774,27 @@ def get_macro_shared_store(self): inline = True if self.tunable.fma_interleave else False return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline), macro_igemm_3d_shared_store_t(self.mc, wei_sst_ctrl, inline) - # computation macro - def get_macro_in_update_hw(self): - inline = True if self.tunable.fma_interleave else False - return self.macro_in_update_hw_t(self.mc, inline) - - def get_macro_in_update_os(self): - inline = True if self.tunable.fma_interleave else False - return self.macro_in_update_os_t(self.mc, amdgpu_precision_data_byte(self.tunable.precision), inline) - def get_macro_move_slice_window(self): inline = True if self.tunable.fma_interleave else False if self.tunable.nxe != 0: - move_slice_window = self.macro_move_slice_window_k_e1_c_t(self.mc, self.tunable, inline) + move_slice_window = self.macro_move_slice_window_block_wise_t(self.mc, self.tunable, inline) else: - move_slice_window = self.macro_move_slice_window_k_nxe0_c_t(self.mc, self.tunable, inline) + move_slice_window = self.macro_move_slice_window_block_wise_1x1_t(self.mc, self.tunable, inline) # return single functor ! 
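+        # the block_wise variant pairs with macro_move_slice_window_block_wise_acc_yx_t
+        # (returned by get_macro_move_slice_window_accumulate below) to fold the consumed
+        # c offset back into x/y; the 1x1 variant only bumps the c offset, since with
+        # nxe == 0 there is no x/y to walk.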
return move_slice_window - def get_macro_set_flag_hw(self): + def get_macro_move_slice_window_accumulate(self): inline = True if self.tunable.fma_interleave else False - return self.macro_set_flag_hw(self.mc, inline) + if self.tunable.nxe != 0: + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + return self.macro_move_slice_window_block_wise_acc_yx_t(self.mc, self.tunable, inline, + label_acc_yx = self.name() + "_acc_yx", + nb_per_thread = nb_per_thread, + m_set_flag_nhw = self.get_macro_set_flag_nhw()) + else: + return None def get_macro_set_flag_nhw(self): inline = True if self.tunable.fma_interleave else False @@ -863,13 +812,13 @@ def get_symbol_global_load_s_stride_d0_d1(self): in_stride_gprs = [s_dummy, s_dummy, s_dummy, - s.s_stride_c] + s_dummy] # [tb_k0, tb_k1, ta_e, ta_c] wei_stride_gprs = [s.s_wei_stride_k0 if tb_k0 != 1 else s_dummy, s.s_wei_stride_k if tb_k1 != 1 else s_dummy, s_dummy, - s.s_stride_c] + s_dummy] if self.in_thread_copy_ndim == 2: s_in_stride_d0 = in_stride_gprs[in_thread_copy_index[0]] @@ -1016,10 +965,6 @@ def emit_kernel_prologue(self): data_byte = amdgpu_precision_data_byte(self.tunable.precision) - m_in_update_hw = self.get_macro_in_update_hw() - m_in_update_os = self.get_macro_in_update_os() - - m_set_flag_hw = self.get_macro_set_flag_hw() m_set_flag_nhw = self.get_macro_set_flag_nhw() s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.get_symbol_global_load_s_stride_d0_d1() @@ -1039,6 +984,8 @@ def emit_kernel_prologue(self): s_dummy = sym_t("s_dummy") + k_pack = ta_c # always use this as k_pack + # start emit self._emit(f"s_load_dwordx2 s[{s.s_p_in((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_in()}") self._emit(f"s_load_dwordx2 s[{s.s_p_wei((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_wei()}") @@ -1066,16 +1013,12 @@ def emit_kernel_prologue(self): self._emit(f"; wei(e, c, k0, k1) thread_length: {ta_e}x{ta_c}x{tb_k0}x{tb_k1}, cluster_length: {ca_e}x{ca_c}x{cb_k0}x{cb_k1}") # weight ic same as input self._emit(f"v_lshrrev_b32 v[{v.v_tmp()}], {igemm_log2(ca_c)}, v0") - self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k, tb_k, True)) + self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k1, tb_k1, True)) self._emit_empty_line() self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") - if self.tunable.nxe != 0: - self._emit(f"v_mov_b32 v[{v.v_in_iy()}], 0") - self._emit(f"v_mov_b32 v[{v.v_in_ix()}], 0") - self._emit(f"s_waitcnt lgkmcnt(0)") self._emit_empty_line() if IGEMM_GTC_FEAT_MAGIC_DIVISION: @@ -1093,13 +1036,13 @@ def emit_kernel_prologue(self): # weight if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_wei_stride_y()}], s[{s.s_x()}], s[{s.s_c()}]") - self._emit(f"s_mul_i32 s[{s.s_wei_stride_k()}], s[{s.s_wei_stride_y()}], s[{s.s_y()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_x()}], s[{s.s_c()}]") + self._emit(f"s_mul_i32 s[{s.s_wei_stride_k()}], s[{s.s_tmp()}], s[{s.s_y()}]") else: self._emit(f"s_mov_b32 s[{s.s_wei_stride_k()}], s[{s.s_c()}]") if tb_k0 != 1: - self._emit(f"s_lshl_b32 s[{s.s_wei_stride_K0()}], s[{s.s_wei_stride_K()}], {igemm_log2(nb_k1)}") + self._emit(f"s_lshl_b32 s[{s.s_wei_stride_k0()}], s[{s.s_wei_stride_k()}], {igemm_log2(nb_k1)}") # output self._emit(f"s_mul_i32 s[{s.s_out_stride_wo()}], s[{s.s_k()}], s[{s.s_group()}]") @@ -1113,25 +1056,25 @@ def emit_kernel_prologue(self): # pad gemm_m, gemm_n if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_dim_b()}], 
s[{s.s_ho()}], s[{s.s_wo()}]") + self._emit(f"s_mul_i32 s[{s.s_dim_br()}], s[{s.s_ho()}], s[{s.s_wo()}]") else: - self._emit(f"s_mul_i32 s[{s.s_dim_b()}], s[{s.s_hi()}], s[{s.s_wi()}]") + self._emit(f"s_mul_i32 s[{s.s_dim_br()}], s[{s.s_hi()}], s[{s.s_wi()}]") - self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_n()}], s[{s.s_dim_b()}]") - self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_m_per_block - 1}, s[{s.s_tmp(2)}]") + self._emit(f"s_mul_i32 s[{s.s_dim_mr()}], s[{s.s_n()}], s[{s.s_dim_br()}]") + self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_m_per_block - 1}, s[{s.s_dim_mr()}]") self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") - self._emit(f"s_lshl_b32 s[{s.s_dim_m()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_mp()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_m_per_block)}") self._emit(f"s_add_u32 s[{s.s_tmp()}], {self.tunable.gemm_n_per_block - 1}, s[{s.s_k()}]") self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_tmp()}], {igemm_log2(self.tunable.gemm_n_per_block)}") - self._emit(f"s_lshl_b32 s[{s.s_dim_n()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_dim_np()}], s[{s.s_tmp(1)}], {igemm_log2(self.tunable.gemm_n_per_block)}") self._emit_empty_line() self._emit(f"; gemm_m_per_block:{self.tunable.gemm_m_per_block}, gemm_n_per_block:{self.tunable.gemm_n_per_block}, source_access_order:{self.tunable.source_access_order}") # calculate group index - self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_dim_m()}], {igemm_log2(self.tunable.gemm_m_per_block)}") - self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_dim_n()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_dim_mp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_dim_np()}], {igemm_log2(self.tunable.gemm_n_per_block)}") self._emit(f"s_mul_i32 s[0], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080010 ; offset:16, width:8") @@ -1143,7 +1086,7 @@ def emit_kernel_prologue(self): self._emit(f"s_mov_b32 s[{s.s_bx()}], s[{s.s_tmp(4)}]") if self.tunable.source_access_order == IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N: - self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_n()}], {igemm_log2(self.tunable.gemm_n_per_block)}") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_np()}], {igemm_log2(self.tunable.gemm_n_per_block)}") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) @@ -1151,7 +1094,7 @@ def emit_kernel_prologue(self): self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_tmp(5), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) else: - self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_m()}], {igemm_log2(self.tunable.gemm_m_per_block)}") + self._emit(f"s_lshr_b32 s[0], s[{s.s_dim_mp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080000 ; offset:0, width:8") self._emit(m_mdiv_u32_ss(s.s_tmp(5), s.s_tmp(4), s.s_bx(), s.s_magic_0(), s.s_tmp(3), '0', s.s_tmp())) @@ -1167,78 +1110,89 @@ def emit_kernel_prologue(self): if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], 
s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) - self._emit(m_int_div_rem_vs(v.v_in_iwo(), v.v_in_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w - self._emit(f"v_mul_lo_u32 v[{v.v_in_iho()}], s[{s.s_stride_h()}], v[{v.v_in_iho()}]") - self._emit(f"v_sub_i32 v[{v.v_in_ihi()}], v[{v.v_in_iho()}], s[{s.s_pad_h()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_in_iwo()}], s[{s.s_stride_w()}], v[{v.v_in_iwo()}]") - self._emit(f"v_sub_i32 v[{v.v_in_iwi()}], v[{v.v_in_iwo()}], s[{s.s_pad_w()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_in_ihi_list(0)}], s[{s.s_stride_h()}], v[{v.v_in_ihi_list(0)}]") + self._emit(f"v_sub_i32 v[{v.v_in_ihi_list(0)}], v[{v.v_in_ihi_list(0)}], s[{s.s_pad_h()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_in_iwi_list(0)}], s[{s.s_stride_w()}], v[{v.v_in_iwi_list(0)}]") + self._emit(f"v_sub_i32 v[{v.v_in_iwi_list(0)}], v[{v.v_in_iwi_list(0)}], s[{s.s_pad_w()}]") self._emit_empty_line() else: if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwi(), v.v_in_ihi(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) - self._emit(m_int_div_rem_vs(v.v_in_iwi(), v.v_in_ihi(), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) - ''' - from here, need track ihi, iwi in move slice window - ''' - - # update flag for batch size - self._emit(f"v_cmp_gt_u32 vcc, s[{self.s_n()}], v[{self.v_in_in()}]") - self._emit(f"v_cndmask_b32 v[{self.v_flag_n()}], 0, 1, vcc") + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) + + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + # update flag for batch size + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshlrev_b32 v[{v.v_in_flag(0)}], 16, 
v[{v.v_tmp()}]") + else: + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshlrev_b32 v[{v.v_in_flag_n()}], 0, v[{v.v_tmp()}]") + self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ig()}], s[{s.s_block_gtc_ig()}], {igemm_log2(data_byte)}") self._emit(f"; calculate in offset") + self._emit(f"s_mov_b32 s[{s.s_in_offset()}], 0") # compute group distance - self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ig()}], s[{s.s_block_gtc_ig()}], {igemm_log2(data_byte)}") self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") - self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_add_u32 s[{s.s_p_in(0)}], s[{s.s_p_in(0)}], s[{s.s_tmp()}]") self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") self._emit_empty_line() self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") # s_in_stride_wi need shift before! self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) - + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi()}]") - self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi()}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(0)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(0)}], v[{v.v_tmp()}]") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") if self.tunable.nxe != 0: - self._emit(m_set_flag_nhw(v.v_in_flag(), v.v_flag_n(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag()}], 16, 1") + self._emit(m_set_flag_nhw(v.v_tmp(), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], 0, v[{v.v_in_flag()}]") + else: + self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag_n()}], 0, 1") + self._emit(m_set_flag_nhw(v.v_in_flag(0), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) self._emit_empty_line() - if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_len_h()}], s[{s.s_ho()}], s[{s.s_stride_h()}]") - self._emit(f"s_mul_i32 s[{s.s_len_w()}], s[{s.s_wo()}], s[{s.s_stride_w()}]") - self._emit(f"s_mov_b32 s[{s.s_lim_h()}], s[{s.s_len_h()}]") - self._emit(f"s_mov_b32 s[{s.s_lim_w()}], s[{s.s_len_w()}]") + # if self.tunable.nxe != 0: + # self._emit(f"s_mul_i32 s[{s.s_len_h()}], s[{s.s_ho()}], s[{s.s_stride_h()}]") + # self._emit(f"s_sub_i32 s[{s.s_lim_h()}], s[{s.s_len_h()}], s[{s.s_pad_h()}]") + # self._emit(f"s_mul_i32 s[{s.s_len_w()}], s[{s.s_wo()}], s[{s.s_stride_w()}]") + # self._emit(f"s_sub_i32 s[{s.s_lim_w()}], s[{s.s_len_w()}], s[{s.s_pad_w()}]") # voffset if ta_nb0 != 1 or ta_nb1 != 1: + ''' thread_stride = na_nb1 if ta_nb0 != 1 else 1 self._emit(f"s_mov_b32 s[{s.s_tmp(5)}], {thread_stride}") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_thread_stride_n(), s.s_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), s.s_tmp())) + self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_thread_stride_n(), 
s.s_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), s.s_tmp())) self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") self._emit(m_mdiv_u32_ss(s.s_thread_stride_w(), s.s_thread_stride_h(), s.s_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), s.s_tmp())) else: - self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_thread_stride_n(), s.s_tmp(5), s.s_dim_b(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) - self._emit(m_int_div_rem_ss(s.s_thread_stride_w(), s.s_thread_stride_h(), s.s_tmp(4), s.s_wo(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_thread_stride_n(), s.s_tmp(5), s.s_dim_br(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_ss(s.s_thread_stride_w(), s.s_thread_stride_h(), s.s_tmp(4), s.s_wo() if self.tunable.nxe != 0 else s.s_wi(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) if self.tunable.nxe != 0: self._emit(f"s_mul_i32 s[{s.s_thread_stride_h()}], s[{s.s_thread_stride_h()}], s[{s.s_stride_h()}]") @@ -1247,14 +1201,15 @@ def emit_kernel_prologue(self): # now let's precompute all the voffset # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w - self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v[{v.v_in_ihi()}]") + self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v[{v.v_in_ihi_list()}]") + self._emit(f"v_mov_b32 v[{v.v_tmp(4)}], v[{v.v_in_iwi_list()}]") self._emit(f"v_mov_b32 v[{v.v_tmp(3)}], v[{v.v_in_in()}]") nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 for i in range(1, nb_per_thread): # v_tmp+4:ihi, v_tmp+5:iwi - self._emit(f"v_add_i32 v[{v.v_tmp(4)}], s[{s.s_thread_stride_w()}], v[{v.v_in_iwi() if i == 1 else v.v_tmp(4) }]") + self._emit(f"v_add_i32 v[{v.v_tmp(4)}], s[{s.s_thread_stride_w()}], v[{v.v_tmp(4) }]") self._emit(f"v_cmpx_le_i32 vcc, s[{s.s_lim_w()}], v[{v.v_tmp(4)}]") - self._emit(f"v_subrev_i32 v[{v.v_tmp(4)}], s[{s.s_len_w()}], v[{v.v_tmp(4)}]") + self._emit(f"v_sub_i32 v[{v.v_tmp(4)}], v[{v.v_tmp(4)}], s[{s.s_len_w()}]") if self.tunable.nxe != 0: self._emit(f"v_add_i32 v[{v.v_tmp(5)}], s[{s.s_stride_h()}], v[{v.v_tmp(5)}]") else: @@ -1263,37 +1218,96 @@ def emit_kernel_prologue(self): self._emit(f"v_add_i32 v[{v.v_tmp(5)}], s[{s.s_thread_stride_h()}], v[{v.v_tmp(5)}]") self._emit(f"v_cmpx_le_i32 vcc, s[{s.s_lim_h()}], v[{v.v_tmp(5)}]") - self._emit(f"v_subrev_i32 v[{v.v_tmp(5)}], s[{s.s_len_h()}], v[{v.v_tmp(5)}]") - self._emit(f"v_add_u32 v[{v.v_tmp(3)}], 1, v[{v..v_tmp(3)}]") + self._emit(f"v_sub_i32 v[{v.v_tmp(5)}], v[{v.v_tmp(5)}], s[{s.s_len_h()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(3)}], 1, v[{v.v_tmp(3)}]") self._emit(f"s_mov_b64 exec, -1") self._emit(f"v_add_u32 v[{v.v_tmp(3)}], s[{s.s_thread_stride_n()}], v[{v.v_tmp(3)}]") + self._emit(f"v_mov_b32 v[{v.v_in_ihi_list(i)}], v[{v.v_tmp(5)}]") + self._emit(f"v_mov_b32 v[{v.v_in_iwi_list(i)}], v[{v.v_tmp(4)}]") + if self.tunable.nxe != 0: # update flag for batch size self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_tmp(3)}]") - self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") - self._emit(m_set_flag_nhw(v.v_flag(i), v.v_tmp(), v.v_tmp(5), v.v_tmp(4), s.s_hi(), s.s_wi())) - - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v.v_tmp(3)}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp(1)}], 0, 1, vcc") + # extra, store this into flag n + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp(1)}], {16 + i}, v[{v.v_in_flag()}]") + self._emit(m_set_flag_nhw(v.v_tmp(), v.v_tmp(1), v.v_tmp(5), v.v_tmp(4), s.s_hi(), s.s_wi())) + 
self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], {i}, v[{v.v_in_flag()}]") + else: + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag_n()}], v[{v.v_tmp(1)}], {i}, v[{v.v_in_flag_n()}]") + self._emit(m_set_flag_nhw(v.v_in_flag(i), v.v_tmp(1), v.v_tmp(5), v.v_tmp(4), s.s_hi(), s.s_wi())) + + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_tmp(3)}]") self._emit(f"v_add_lshl_u32 v[{v.v_tmp(2)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_tmp(5)}]") self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") self._emit(f"v_add_u32 v[{v.v_in_os(i)}], v[{v.v_tmp(2)}], v[{v.v_tmp()}]") + ''' + thread_stride = na_nb1 if ta_nb0 != 1 else 1 + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + + for i in range(1, nb_per_thread): + self._emit(f"s_mov_b32 s1, {thread_stride * i}") + self._emit(f"v_add_u32 v[{v.v_tmp()}], s1, v[{v.v_in_inb()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_inb()}], v[{v.v_tmp()}]") + if self.tunable.nxe != 0: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + self._emit(f"v_mul_lo_u32 v[{v.v_in_ihi_list(i)}], s[{s.s_stride_h()}], v[{v.v_in_ihi_list(i)}]") + self._emit(f"v_sub_i32 v[{v.v_in_ihi_list(i)}], v[{v.v_in_ihi_list(i)}], s[{s.s_pad_h()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_in_iwi_list(i)}], s[{s.s_stride_w()}], v[{v.v_in_iwi_list(i)}]") + self._emit(f"v_sub_i32 v[{v.v_in_iwi_list(i)}], v[{v.v_in_iwi_list(i)}], s[{s.s_pad_w()}]") + self._emit_empty_line() + + else: + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) + + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(i)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(i)}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], 
s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os(i)}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + # update flag for batch size + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], {16 + i}, v[{v.v_in_flag(0)}]") + self._emit(m_set_flag_nhw(v.v_tmp(1), v.v_tmp(), v.v_in_ihi_list(i), v.v_in_iwi_list(i), s.s_hi(), s.s_wi())) + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp(1)}], {i}, v[{v.v_in_flag()}]") + else: + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") + self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag_n()}], v[{v.v_tmp()}], {i}, v[{v.v_in_flag_n()}]") + self._emit(m_set_flag_nhw(v.v_in_flag(i), v.v_tmp(), v.v_in_ihi_list(i), v.v_in_iwi_list(i), s.s_hi(), s.s_wi())) else: pass # load in self._emit(self.global_load_in()) self._emit_empty_line() - self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") - # config weight range - #self._emit("; config for weight range") - #self._emit(f"s_mul_i32 s[{s.s_p_wei(2)}], s[{s.s_wei_stride_k() if self.tunable.nxe != 0 else s.s_c()}], s[{s.s_k()}]") - #self._emit(f"s_lshl_b32 s[{s.s_p_wei(2)}], s[{s.s_p_wei(2)}], {igemm_log2(data_byte)}") self._emit(f"s_mov_b32 s[{s.s_p_wei(3)}], 0x27000") self._emit(f"; calculate wei offset") @@ -1301,10 +1315,10 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") self._emit(f"s_add_u32 s[{s.s_p_wei()}], s[{s.s_p_wei()}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") + self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") - self._emit(f"v_add_u32 v[{v.v_cur_k()}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_cur_k()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(1)}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(1)}]") self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(data_byte)}") self._emit_empty_line() @@ -1326,28 +1340,30 @@ def emit_kernel_prologue(self): self._emit(self.thread_mapping(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp())) else: self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") - self._emit(self.xdlops_mapping.get_gemm_index_for_src_matrix(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp())) + self._emit(self.xdlops_mapping.get_gemm_index_for_src_matrix(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp(), k_pack)) self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") self._emit(self.xdlops_mapping.get_gemm_index_for_dst_matrix(v.v_co_sst(), v.v_co_sld(), v.v_tmp(5), v.v_tmp())) ''' gemm_k * gemm_m * k_pack ''' - self._emit(f"; LDS store, in: e,c,nb0,nb1: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}") - if ca_nb1 == 1: - # TODO: remove this path, not possible go here - assert False + self._emit(f"; LDS store, in: e,c,nb0,nb1: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, k_pack:{k_pack}") + if k_pack != 1: + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_in_inb()}]") + self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") + 
self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_tmp(2)}]") else: - if ca_nb0 == 1: - self._emit(f"v_mov_b32 v[{v.v_tmp()}], v[{v.v_in_inb()}]") - else: - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ta_in0()}], {igemm_log2(na_nb1)}, v[{v.v_in_inb()}]") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(na_nb0*na_nb1)}, v[{v.v_tmp()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_in_inb()}]") self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") self._emit_empty_line() - self._emit(f"; LDS store, wei: e,c,k: {ta_e}x{ta_c}x{tb_k}, {ca_e}x{ca_c}x{cb_k}") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(nb_k)}, v[{v.v_wei_ik()}]") + self._emit(f"; LDS store, wei: e,c,k: {ta_e}x{ta_c}x{tb_k0}x{tb_k1}, {ca_e}x{ca_c}x{cb_k0}x{cb_k1}, k_pack:{k_pack}") + if k_pack != 1: + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_wei_ik()}]") + self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(nb_k0*nb_k1 * k_pack)}, v[{v.v_tmp(2)}]") + else: + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(nb_k0*nb_k1 * k_pack)}, v[{v.v_wei_ik()}]") self._emit(f"v_lshlrev_b32 v[{v.v_sst_b_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") self._emit(f"v_add_u32 v[{v.v_sst_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_b_os()}]") self._emit_empty_line() @@ -1366,6 +1382,11 @@ def emit_kernel_prologue(self): self._emit(self.coalescing_store.init_co_sub_n_index(v.v_co_sub_n_index(), '0', v.v_tmp())) self._emit_empty_line() + if self.tunable.nxe != 0: + self._emit(f"v_add_u32 v[{v.v_tmp()}], s[{s.s_block_gtc_ik()}], v[{v.v_co_sub_n_index()}]") + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp()}]") + self._emit(f"v_cndmask_b32 v[{v.v_out_flag()}], 0, 1, vcc") + ''' a good news for nhwc and coalescing output is that, we can treat gemm_m (n*ho*wo) as a single dimension, and use sgpr to stride along this dimension. 
this is much easier @@ -1376,190 +1397,76 @@ def emit_kernel_prologue(self): self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") - self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_in0()}], {igemm_log2(unmerge_sub_n1 * data_byte)}") - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") - self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") - self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") + # self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_in0()}], {igemm_log2(unmerge_sub_n1 * data_byte)}") + # self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") + # self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") + # self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") + # self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") self._emit_empty_line() self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_ik()}], {igemm_log2(data_byte)}") - self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp(3)}]") self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out()}+1], 0") self._emit_empty_line() - self._emit(f"; compute v_co_sub_m_index along nb0 x nb1 : {na_nb0}x{na_nb1}") - if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N0_N1B: - if na_nb1 != 1: - self._emit(f"v_and_b32 v[{v.v_out_in1b()}], {na_nb1 - 1}, v[{v.v_co_sub_m_index()}] ; => N1B") - if na_nb0 != 1: - self._emit(f"v_lshrrev_b32 v[{v.v_out_in0()}], {igemm_log2(na_nb1)}, v[{v.v_co_sub_m_index()}] ; => N0") - else: - assert False, "un implemented, should rarely be used" - else: - assert False + self._emit(self.try_shift_stride(s.s_out_stride_wo, igemm_log2(data_byte))) + # self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_out_stride_wo()}], s[{s.s_block_gtc_inb()}]") + # self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_out_stride_wo()}], s[{s.s_block_gtc_inb()}]") + # self._emit(f"s_add_u32 s[{s.s_p_out(0)}], s[{s.s_p_out(0)}], s[{s.s_tmp(0)}]") + # self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") + # self._emit_empty_line() + self._emit(f"v_add_u32 v[{v.v_out_inb()}], s[{s.s_block_gtc_inb()}], v[{v.v_co_sub_m_index()}] ; total n*ho*wo") + self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], s[{s.s_out_stride_wo()}], v[{v.v_out_inb()}]") + self._emit(f"v_lshlrev_b32 v[{v.v_tmp()}], {igemm_log2(data_byte)}, v[{v.v_co_sub_n_index()}]") + self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp()}]") - # TODO: extend tensor size, here vgpr only have 32bit - self._emit(f"; compute from nb1") - self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_in1b()}], v[{v.v_out_in1b()}]") - if self.tunable.nxe != 0: - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) - else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) - 
self._emit(m_int_div_rem_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) - self._emit_empty_line() - else: - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_b(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) - else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_out_in1(), v.v_tmp(5), s.s_dim_b(), v.v_tmp(), s.s_tmp())) - self._emit(m_int_div_rem_vs(v.v_out_iwo(), v.v_out_iho(), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) - self._emit_empty_line() - self._emit_empty_line() - self._emit(f"; add in_in0, in_in1") - if na_nb0 != 1: - #if gemm_m_unmerge_cluster == 0: - self._emit(f"v_lshl_or_b32 v[{v.v_tmp(1)}], v[{v.v_out_in0()}], {igemm_log2(unmerge_sub_n1)}, v[{v.v_out_in1()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], s[{s.s_out_stride_n()}], v[{v.v_tmp(1)}]") - # else: - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_n()}], v[{v.v_out_in1()}]") - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_out_stride_n0()}], v[{v.v_out_in0()}]") - # self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_tmp()}], v[{v.v_tmp(1)}]") - else: - self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], s[{s.s_out_stride_n()}], v[{v.v_out_in1()}]") self._emit(f"; add i_k") - ## gemm_m_unmerge_cluster is always 0 - # if gemm_m_order == IGEMM_FWD_GTC_LDS_STORE_ORDER_GEMM_M_K0_K1: - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_co_sub_m_index()}]") - # else: - # if na_k0 == 1: - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_co_sub_m_index()}]") - # else: - # if na_k1 == 1: - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_co_sub_m_index()}]") - # else: - # self._emit(f"v_and_b32 v[{v.v_tmp()}], {na_k0 - 1}, v[{v.v_co_sub_m_index()}] ; => k0") - # self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(na_k0)}, v[{v.v_co_sub_m_index()}] ; => k1") - # self._emit(f"v_lshl_or_b32 v[{v.v_tmp(1)}], v[{v.v_tmp()}], {igemm_log2(na_k1)}, v[{v.v_tmp(1)}]") - # self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_out_stride_k()}], v[{v.v_tmp(1)}]") - - self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_co_sub_n_index()}]") # n, add to k - - self._emit(f"; add ho, wo") - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_k()}], s[{s.s_group()}] ; stride for wo") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_wo() if self.tunable.nxe != 0 else s.s_wi()}], v[{v.v_out_iho()}]") - self._emit(f"v_add_u32 v[{v.v_tmp(2)}], v[{v.v_tmp(1)}], v[{v.v_out_iwo()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_tmp()}], v[{v.v_tmp(2)}]") - self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp()}]") - self._emit(f"v_lshlrev_b32 v[{v.v_out_os()}], {igemm_log2(data_byte)}, v[{v.v_out_os()}]") - if self.tunable.nxe != 0: - self._emit(m_set_flag_hw(v.v_out_flag(), v.v_out_iho(), v.v_out_iwo(), s.s_ho(), s.s_wo())) - self._emit(f"; move slice stride") - self._emit(f"s_mov_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_c()}]") + self._emit(f"s_lshl_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_c()}], {igemm_log2(data_byte)}") + if self.tunable.nxe != 0: - self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c()}], {na_c}") - self._emit(f"s_mul_i32 
s[{s.s_move_slice_k_stride_c()}], s[{s.s_move_slice_k_c()}], {igemm_log2(data_byte)}") + self._emit(f"s_mov_b32 s[{s.s_tmp()}], {na_c}") + self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_tmp()}], {igemm_log2(data_byte)}") else: self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_c()}], {na_c * data_byte}") if self.tunable.nxe != 0: - self._emit(f"s_lshl_b32 s[{s.s_tmp(2)}], s[{s.s_c()}], {igemm_log2(data_byte)}") - # diff_y, ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h - self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_wi()}], s[{s.s_tmp(4)}]") - self._emit(f"s_mul_i32 s[{s.s_move_slice_k_in_stride_diff_y()}], s[{s.s_dilation_h()}], s[{s.s_tmp(4)}]") - self._emit(self.try_shift_stride(s.s_move_slice_k_in_stride_diff_y, igemm_log2(data_byte))) - self._emit(f"s_sub_u32 s[{s.s_move_slice_k_in_stride_diff_y()}], s[{s.s_move_slice_k_in_stride_diff_y()}], s[{s.s_tmp(2)}]") - # diff_x, iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w, hence need compute s_dilation_w per increase - self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_c()}], s[{s.s_group()}]") - self._emit(f"s_mul_i32 s[{s.s_move_slice_k_in_stride_diff_x()}], s[{s.s_dilation_w()}], s[{s.s_tmp(4)}]") - self._emit(self.try_shift_stride(s.s_move_slice_k_in_stride_diff_x, igemm_log2(data_byte))) - self._emit(f"s_sub_u32 s[{s.s_move_slice_k_in_stride_diff_x()}], s[{s.s_move_slice_k_in_stride_diff_x()}], s[{s.s_tmp(2)}]") - - self._emit(f"s_mul_i32 s[{s.s_in_diff_sub_wi()}], s[{s.s_x()}], s[{s.s_dilation_w()}] ") - - # assert na_c0 * na_c1e == self.tunable.gemm_k_per_block and nb_c0 * nb_c1e == self.tunable.gemm_k_per_block + # s_in_diff_wi : s_dilation_w * s_in_stride_wi + # s_in_diff_hi : s_dilation_h * s_in_stride_hi - (x - 1) * s_dilation_w * s_in_stride_wi, always possitive + # s_dilation_w_x : -1* (x - 1) * s_dilation_w + self._emit(f"s_mov_b32 s[{s.s_move_slice_k_ix()}], 0") + self._emit(f"s_mul_i32 s[{s.s_in_diff_wi()}], s[{s.s_dilation_w()}], s[{s.s_in_stride_wi()}]") # shifted + self._emit(f"s_mul_i32 s[{s.s_tmp(3)}], s[{s.s_x()}], 1") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_in_diff_wi()}], s[{s.s_tmp(3)}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_in_stride_wi()}], s[{s.s_wi()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_tmp(1)}], s[{s.s_dilation_h()}]") + self._emit(f"s_sub_i32 s[{s.s_in_diff_hi()}], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") + self._emit(f"s_mul_i32 s[{s.s_dilation_w_x()}], s[{s.s_dilation_w()}], s[{s.s_tmp(3)}]") + self._emit(f"s_mul_i32 s[{s.s_dilation_w_x()}], s[{s.s_dilation_w_x()}], -1") - # if self.tunable.nxe != 0: - # #assert na_c0 * na_c1e == nb_c0 * nb_c1e - # self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c1e()}], {na_c0 * na_c1e}") - # if IGEMM_GTC_FEAT_MAGIC_DIVISION: - # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") - # self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_magic_2(), s.s_tmp(3), s.s_stride_c(), s.s_tmp())) - # self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080018 ; offset:24, width:8") - # self._emit(m_mdiv_u32_ss(s.s_move_slice_k_x(), s.s_move_slice_k_y(), s.s_tmp(4), s.s_magic_3(), s.s_tmp(3), s.s_x(), s.s_tmp())) - # else: - # self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_move_slice_k_c1(), s.s_move_slice_k_c1e(), s.s_stride_c(), v.v_tmp(4), v.v_tmp(), s.s_tmp())) - # self._emit(m_int_div_rem_ss(s.s_move_slice_k_x(), s.s_move_slice_k_y(), s.s_tmp(4), s.s_x(), v.v_tmp(4), v.v_tmp(), s.s_tmp())) - # else: - # #assert na_c1e == nb_c1e - # 
#self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c1()}], {nb_c1e}") - # self._emit(f"s_mov_b32 s[{s.s_move_slice_k_c1e()}], {na_c1e}") - # self._emit_empty_line() - - # m_move_slice_window_ta, m_move_slice_window_tb = self.get_macro_move_slice_window() + self._emit_empty_line() # if self.tunable.nxe != 0: - # # assert s.s_out_stride_k.label not in self.dict_shifted_stride and s.s_wei_stride_k.label not in self.dict_shifted_stride - # if s.s_stride_c.label not in self.dict_shifted_stride: - # self._emit(m_move_slice_window_tb.init_stride_c(s.s_stride_c(), s.s_in_stride_c_c1(), - # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1())) - # else: - # self._emit(f"s_lshr_b32 s[{s.s_tmp(3)}], s[{s.s_stride_c()}], {utility_log2(data_byte)}") - # self._emit(m_move_slice_window_tb.init_stride_c(s.s_tmp(3), s.s_in_stride_c_c1(), - # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1())) + # # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) + # self._emit(self.try_shift_stride(s.s_wei_stride_k, igemm_log2(data_byte))) + # # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) # else: - # if self.is_1d_move_slice_k(): - # self._emit(m_move_slice_window_tb.init_stride_c(s.s_stride_hw(), s.s_in_stride_c_c1(), s.s_move_slice_k_c1e())) - # else: - # self._emit(m_move_slice_window_tb.init_stride_c(s.s_stride_hw(), s.s_in_stride_c_c1(), - # s.s_in_stride_c_c0_c1_diff(), s.s_move_slice_k_c1e())) - - - # if not self.is_1d_move_slice_k(): - # self._emit(f"s_mov_b32 s[{s.s_gemm_k_num_c1()}], {unmerge_sub_tb_c1}") - #if self.tunable.nxe != 0: - # self._emit(f"s_mul_i32 s[{s.s_knum()}], s[{s.s_stride_c()}], s[{s.s_c()}]") - #else: - # self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_c()}]") - self._emit_empty_line() - - #self._emit(self.try_shift_stride(s.s_in_stride_c_c1, igemm_log2(data_byte))) - #self._emit(self.try_shift_stride(s.s_wei_stride_k_k1, igemm_log2(data_byte))) - #self._emit(self.try_shift_stride(s.s_in_stride_c_c0_c1_diff, igemm_log2(data_byte))) - #self._emit(self.try_shift_stride(s.s_wei_stride_k_k0_k1_diff, igemm_log2(data_byte))) - - if self.tunable.nxe != 0: - # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) - self._emit(self.try_shift_stride(s.s_wei_stride_k, igemm_log2(data_byte))) - # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) - else: - # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) - self._emit(self.try_shift_stride(s.s_c, igemm_log2(data_byte))) - # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) + # # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) + # self._emit(self.try_shift_stride(s.s_c, igemm_log2(data_byte))) + # # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) # self._emit(self.try_shift_stride(s.s_move_slice_k_c1e, igemm_log2(data_byte))) self._emit(f"s_mov_b32 s[{s.s_p_out(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_out(3)}], 0x27000") - def emit_kernel_fma_main_loop(self): s = self.sgpr v = self.vgpr data_byte = amdgpu_precision_data_byte(self.tunable.precision) - # m_move_slice_window_ta, m_move_slice_window_tb = self.get_macro_move_slice_window() - m_move_slice_window = self.get_macro_move_slice_window() - m_set_flag_hw = self.get_macro_set_flag_hw() + m_move_slice_window = self.get_macro_move_slice_window() + m_move_slice_window_accumulate = self.get_macro_move_slice_window_accumulate() def move_slice_window_b(): ''' @@ -1567,20 +1474,70 @@ def move_slice_window_b(): ''' if 
self.tunable.nxe != 0: with self._deferred_context(): - self._emit(m_move_slice_window(v.v_move_slice_k_iy(), v.v_move_slice_k_ix(), v.v_move_slice_k_ic(), - s.s_gemm_k_num_x(), s.s_gemm_k_num_c(), s.s_move_slice_k_c(), v.v_in_os(), v.v_wei_os(), - s.s_move_slice_k_in_stride_diff_y(), s.s_move_slice_k_in_stride_diff_x(), s.s_move_slice_k_stride_c(), - v.v_in_ihi(), v.v_in_iwi(), s.s_dilation_h(), s.s_dilation_w(), s.s_in_diff_sub_wi())) - self._emit(m_set_flag_hw(v.v_in_flag(), v.v_in_ihi(), v.v_in_iwi(), s.s_hi(), s.s_wi())) + self._emit(m_move_slice_window( + s.s_in_offset(), + v.v_wei_os(), + s.s_move_slice_k_stride_c(), + s.s_gemm_k_num_c(), + s.s_flag_need_acc_yx())) return self._get_deferred() else: with self._deferred_context(): - self._emit(m_move_slice_window(v.v_in_os(), v.v_wei_os(),s.s_move_slice_k_stride_c())) + self._emit(m_move_slice_window( + s.s_in_offset(), + v.v_wei_os(), + s.s_move_slice_k_stride_c())) return self._get_deferred() def move_slice_window_a(): return '' + def move_slice_window_acc(): + if self.tunable.nxe == 0: + return '' + else: + with self._deferred_context(): + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self._emit(m_move_slice_window_accumulate( + s.s_in_offset(), + v.v_in_os(), + v.v_in_ihi_list(), + v.v_in_iwi_list(), + v.v_in_flag(), + s.s_flag_need_acc_yx(), + s.s_move_slice_k_ix(), + s.s_x(), + s.s_in_diff_hi(), + s.s_in_diff_wi(), + s.s_dilation_h(), + s.s_dilation_w(), + s.s_dilation_w_x(), + s.s_hi(), + s.s_wi(), + v.v_tmp(), + s.s_tmp())) + else: + self._emit(m_move_slice_window_accumulate( + s.s_in_offset(), + v.v_in_os(), + v.v_in_ihi_list(), + v.v_in_iwi_list(), + v.v_in_flag(), + v.v_in_flag_n(), + s.s_flag_need_acc_yx(), + s.s_move_slice_k_ix(), + s.s_x(), + s.s_in_diff_hi(), + s.s_in_diff_wi(), + s.s_dilation_h(), + s.s_dilation_w(), + s.s_dilation_w_x(), + s.s_hi(), + s.s_wi(), + v.v_tmp(), + s.s_tmp())) + return self._get_deferred() + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: fctrl = ctrl_fma_main_loop_t() fctrl.thread_m = self.tunable.thread_tile_m @@ -1641,19 +1598,24 @@ def move_slice_window_a(): fctrl.global_load_b_functor = self.global_load_in fctrl.shared_store_a_functor = self.shared_store_wei fctrl.shared_store_b_functor = self.shared_store_in + + ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + fctrl.lds_k_pack = ta_c + if ctrl_xdlops_mapping.wave_step_m == 1: fctrl.shared_load_a_functor = inst_ds_read_t(data_byte) # xdlops load from LDS always single load else: assert ctrl_xdlops_mapping.wave_step_m == 2, "currently only support wave_step_m is 2" - fctrl.shared_load_a_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ctrl_xdlops_mapping.wave_tile_m * data_byte, sym_t(self.vgpr.v_tmp(4))) + fctrl.shared_load_a_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ta_c*ctrl_xdlops_mapping.wave_tile_m * data_byte, sym_t(self.vgpr.v_tmp(4))) if ctrl_xdlops_mapping.wave_step_n == 1: fctrl.shared_load_b_functor = inst_ds_read_t(data_byte) # xdlops load from LDS always single load else: assert ctrl_xdlops_mapping.wave_step_n == 2, "currently only support wave_step_n is 2" - fctrl.shared_load_b_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ctrl_xdlops_mapping.wave_tile_n * data_byte, sym_t(self.vgpr.v_tmp(5))) + fctrl.shared_load_b_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ta_c*ctrl_xdlops_mapping.wave_tile_n * data_byte, sym_t(self.vgpr.v_tmp(5))) fctrl.move_slice_window_a_functor = 
move_slice_window_a fctrl.move_slice_window_b_functor = move_slice_window_b + fctrl.move_slice_window_accumule_functor = move_slice_window_acc if self.tunable.nxe != 0 else None # sympol type fctrl.v_a = v.v_a @@ -1691,8 +1653,8 @@ def emit_kernel_epilogue(self): else: a = self.agpr self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None, - s.s_out_stride_n0() if ta_nb0 != 1 else None, s.s_out_stride_wo(), - s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else None, s.s_k(), v.v_cur_k(), s.s_block_gtc_ik(), v.v_co_sub_m_index(), v.v_tmp())) + None, s.s_out_stride_wo(), + s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else None, s.s_dim_mr(), v.v_out_inb(), s.s_block_gtc_inb(), v.v_co_sub_m_index(), v.v_tmp())) self._emit_front(f"{self.label_out}:") diff --git a/igemm/algo/mfma_main_loop.py b/igemm/algo/mfma_main_loop.py index 85e96527..65f0a781 100644 --- a/igemm/algo/mfma_main_loop.py +++ b/igemm/algo/mfma_main_loop.py @@ -53,6 +53,7 @@ def __init__(self): self.shared_load_b_functor = None self.move_slice_window_a_functor = None self.move_slice_window_b_functor = None + self.move_slice_window_accumule_functor = None # symbol type self.v_a = None @@ -94,6 +95,7 @@ def emit(self): f_move_slice_window_a = self.ctrl.move_slice_window_a_functor f_move_slice_window_b = self.ctrl.move_slice_window_b_functor + f_move_slice_window_acc = self.ctrl.move_slice_window_accumule_functor v_a = self.ctrl.v_a v_b = self.ctrl.v_b @@ -131,7 +133,7 @@ def mapped_ioffset(i_k, width_byte, pad_pixel, offset = 0): k_pack = self.ctrl.lds_k_pack i_k0 = i_k // k_pack i_kp = i_k % k_pack - return i_k0 * (width_byte * k_pack + pad_pixel * data_byte) + i_kp * k_pack * data_byte + offset + return i_k0 * (width_byte * k_pack + pad_pixel * data_byte) + i_kp * data_byte + offset * k_pack # mi = mapped_ioffset def mi_m(i_k, offset = 0): @@ -173,6 +175,8 @@ def mfma_loop_repeat_1x1_lp2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -199,8 +203,8 @@ def do_unroll_k_1x1_sub(): self._emit(f's_waitcnt lgkmcnt(2)') self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_m * k_per_inst - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_n * k_per_inst + self._emit(f_sld_a(v_a(local_buffer_m), v_sld_a_os(), lds_base_m + mi_m((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_m * k_per_inst + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst))) # (2*i_k+3) * lds_width_n * k_per_inst do_unroll_k_1x1_sub() self._emit(f_move_slice_window_b()) @@ -210,6 +214,8 @@ def do_unroll_k_1x1_sub(): self._emit_empty_line() + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f's_waitcnt lgkmcnt(0)') self._emit(f"s_barrier") self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") @@ -312,6 +318,8 @@ def do_interleave_share_store(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -339,6 +347,8 @@ def do_interleave_share_store(): se_last = 
create_scheduler(self.mc, mbb_list_last) self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0)) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1)) else: mbb_list_last = [create_machine_basic_block(do_interleave_unroll_k_last(), group_mbb_by_end_of_inst_op="v_mfma"), @@ -346,6 +356,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(do_interleave_gload_and_move_slice_window()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1)) # Label: finishing of fma body @@ -380,6 +392,8 @@ def mfma_loop_repeat_2x2_lp2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) #self._emit(f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") #self._emit(f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") @@ -478,6 +492,8 @@ def do_unroll_k_sub(): self._emit_empty_line() # 2nd fma + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f's_waitcnt lgkmcnt(0)') self._emit(f"s_barrier") self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") @@ -657,14 +673,14 @@ def do_interleave_unroll_k_sub(): # 3rd fma self._emit(f's_waitcnt lgkmcnt(5)') self._emit(mfma_step_mxn(1, 0, 1, 1)) - self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + (2*i_k+3) * lds_width_n * k_per_inst) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") + self._emit(f_sld_b(v_b(local_buffer_n), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {0}") self._emit_empty_line() # 4th fma self._emit(mfma_step_mxn(1, 1, 1, 1)) self._emit(f_sld_b(v_b(local_buffer_n + repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n((2*i_k+3) * k_per_inst, lds_width_n//2)) + f" ; load i_k:{2*i_k+3} into local buffer {1}, repeat {1}") # (2*i_k+3) * lds_width_n * k_per_inst + lds_width_n//2 if i_k == unroll_k_sub - 1: - self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + (unroll_k // k_per_inst - 1) * lds_width_m * k_per_inst + lds_width_m // 2) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") + self._emit(f_sld_a(v_a(local_buffer_m + repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m((unroll_k // k_per_inst - 1) * k_per_inst, lds_width_m//2)) + f" ; load i_k:{unroll_k // k_per_inst - 1} into local buffer {1}, repeat {1}") self._emit_empty_line() return self._get_deferred() @@ -735,6 +751,8 @@ def do_interleave_share_store(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -757,6 +775,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0)) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma 
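The mapped_ioffset() helper changed earlier in this file is the core of the lds_k_pack layout: k_pack consecutive gemm_k elements are stored contiguously per tile element, so a logical k index splits into a group index (stride of a whole k_pack-wide row) and an in-group index (stride of one element). A minimal sketch of that byte-offset computation, assuming fp32 (data_byte = 4) and no LDS padding; the function name and the worked numbers below are illustrative only, not part of the patch:

    def lds_k_pack_offset(i_k, width_byte, k_pack, data_byte=4, offset=0):
        # i_k        : logical gemm_k index of this ds_read
        # width_byte : bytes of one gemm_m (or gemm_n) row before k_pack packing
        # offset     : extra byte offset inside the unpacked row, e.g. lds_width_m // 2
        i_k0, i_kp = divmod(i_k, k_pack)                 # k_pack group / element inside the group
        return i_k0 * width_byte * k_pack + i_kp * data_byte + offset * k_pack

    # example: width_byte = 512 (128 fp32 per row), k_pack = 4, i_k = 5
    # -> i_k0 = 1, i_kp = 1 -> 1*512*4 + 1*4 = 2052 bytes

Because the per-group stride grows by a factor of k_pack, the ds_read2 offsets for wave_step_m/n above are scaled by ta_c as well.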
self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) else: @@ -765,6 +785,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(do_interleave_gload_and_move_slice_window()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) @@ -837,6 +859,8 @@ def mfma_loop_repeat_2x2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") self._emit(f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") @@ -901,6 +925,8 @@ def mfma_loop_repeat_2x2(): self._emit(mfma_step_mxn(0, 1)) # wait global and store to LDS + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") self._emit(f_sst_b()) self._emit(f"s_waitcnt vmcnt(0)") @@ -946,8 +972,8 @@ def mfma_loop_repeat_2x2(): self._emit(f_sld_a(v_a(), v_sld_a_os(), lds_base_m)) self._emit(f_sld_b(v_b(), v_sld_b_os(), lds_base_n)) - self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + lds_width_n // 2 )) - self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + lds_width_m // 2 )) + self._emit(f_sld_b(v_b(repeat_n_thread_offset), v_sld_b_os(), lds_base_n + mi_n(0, lds_width_n // 2 ))) + self._emit(f_sld_a(v_a(repeat_m_thread_offset), v_sld_a_os(), lds_base_m + mi_m(0, lds_width_m // 2 ))) # self._emit(f".itr_k = 0") # self._emit(f".rept {unroll_k // k_per_inst - 1}") @@ -1009,6 +1035,8 @@ def mfma_loop_repeat_2x1_lp2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -1096,6 +1124,8 @@ def do_unroll_k_sub(): self._emit_empty_line() # 2nd fma + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f's_waitcnt lgkmcnt(0)') self._emit(f"s_barrier") self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") @@ -1285,6 +1315,8 @@ def do_interleave_share_store(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -1307,6 +1339,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0)) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) else: @@ -1315,6 
+1349,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(do_interleave_gload_and_move_slice_window()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) @@ -1364,6 +1400,8 @@ def mfma_loop_repeat_1x2_lp2(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -1451,6 +1489,8 @@ def do_unroll_k_sub(): self._emit_empty_line() # 2nd fma + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f's_waitcnt lgkmcnt(0)') self._emit(f"s_barrier") self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") @@ -1637,6 +1677,8 @@ def do_interleave_share_store(): # right after clear acc self._emit(f_move_slice_window_b()) self._emit(f_move_slice_window_a()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") @@ -1659,6 +1701,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(se_sub.lower(interleave_pattern=INTERLEAVE_PTN_0)) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) else: @@ -1667,6 +1711,8 @@ def do_interleave_share_store(): se_last = create_scheduler(self.mc, mbb_list_last) self._emit(do_interleave_gload_and_move_slice_window()) + if f_move_slice_window_acc != None: + self._emit(f_move_slice_window_acc()) mbb_0_mfma_cnt_after_branch_to_start = 2 * cxm.wave_step_m * cxm.wave_step_n - 1 # number of mfma not count into share store interleave slot, check do_interleave_unroll_k_last for last 2 mfma self._emit(se_last.lower(interleave_pattern=INTERLEAVE_PTN_1, mbb_0_mfma_cnt_after_branch_to_start=mbb_0_mfma_cnt_after_branch_to_start)) @@ -1707,7 +1753,7 @@ def do_interleave_share_store(): # start emit - self._emit(f"; start MFMA loop, {cxm.wave_tile_m}x{cxm.wave_tile_n} wave tile with {cxm.wave_repeat_m}x{cxm.wave_repeat_n} repeat, {cxm.wave_step_m}x{cxm.wave_step_n} step") + self._emit(f"; start MFMA loop, {cxm.wave_tile_m}x{cxm.wave_tile_n} wave tile with {cxm.wave_repeat_m}x{cxm.wave_repeat_n} repeat, {cxm.wave_step_m}x{cxm.wave_step_n} step, k_pack:{self.ctrl.lds_k_pack}") self._emit(f"s_waitcnt vmcnt({f_gld_a.get_issues()})") self._emit(f_sst_b()) diff --git a/igemm/algo/shared_memory.py b/igemm/algo/shared_memory.py index 0d95c97b..fcda13c6 100644 --- a/igemm/algo/shared_memory.py +++ b/igemm/algo/shared_memory.py @@ -897,10 +897,9 @@ def expr(self): if ctrl.length_d0 == 1 or ctrl.length_d1 == 1: # this is indeed a 2d case. 
- + ds_write = inst_ds_write_t(ctrl.length_dp * data_byte) if ctrl.length_d0 == 1 and ctrl.length_d1 == 1: # further, 1d case - ds_write = inst_ds_write_t(ctrl.length_dp * data_byte) self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}')) issue_cnt += ds_write.get_issues() @@ -915,7 +914,7 @@ def expr(self): else: # nhwc almost all case goes here for i_d in range(length_d): - self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}+{i_d*ctrl.length_dp}', i_d * ctrl.stride_d)) + self._emit(ds_write(f'{self.v_sst_os()}', f'{self.v_src()}+{i_d*ctrl.length_dp}', i_d * stride_d)) issue_cnt += ds_write.get_issues() else: assert False, "un implemented yet" diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py index 728d0259..a9f7a5a7 100755 --- a/igemm/algo/xdlops_mapping.py +++ b/igemm/algo/xdlops_mapping.py @@ -369,7 +369,7 @@ def __init__(self, mc, ctrl): mc_base_t.__init__(self, mc) assert type(ctrl) is ctrl_xdlops_mapping_t self.ctrl = ctrl - def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp4): + def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp4, k_pack = 1): ''' notice! this is to calculate LDS offset for A/B matrix input, it is not the same as C matrix output layout, due to xdlops C matrix output describe is in coalescint_store @@ -379,32 +379,44 @@ def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp #print(f"ctrl.block_n_per_wave()={ctrl.block_n_per_wave()}, ctrl.block_m_per_wave()={ctrl.block_m_per_wave()}") assert ctrl.block_n() == ctrl.block_m() and ctrl.block_k() * ctrl.block_n() * ctrl.block_n_per_wave() * ctrl.block_m_per_wave() == AMDGPU_WAVE_SIZE with self._deferred_context(): - self._emit(f"; xdlops mapping, get source matrix gemm index") + self._emit(f"; xdlops mapping, get source matrix gemm index, k_pack:{k_pack}") self._emit(f"v_and_b32 v[{v_gemm_in}], {ctrl.block_n() - 1}, v[{v_thread_id}] ; block_n index ") self._emit(f"v_and_b32 v[{v_gemm_im}], {ctrl.block_m() - 1}, v[{v_thread_id}] ; block_m index ") + if k_pack != 1: + self._emit(f"v_lshlrev_b32 v[{v_gemm_in}], {utility_log2(k_pack)}, v[{v_gemm_in}] ; shift left k_pack:{k_pack}") + self._emit(f"v_lshlrev_b32 v[{v_gemm_im}], {utility_log2(k_pack)}, v[{v_gemm_im}] ; shift left k_pack:{k_pack}") + self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_n())}, v[{v_thread_id}]") if ctrl.block_k() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 0], {ctrl.block_k() - 1}, v[{v_thread_id}] ; block_k_per_wave index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n)}, v[{v_gemm_in}]") - self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m)}, v[{v_gemm_im}]") + if k_pack != 1: + self._emit(f"v_and_b32 v[{v_tmp4} + 1], {k_pack - 1}, v[{v_tmp4} + 0] ; and k_pack:{k_pack}") + self._emit(f"v_lshrrev_b32 v[{v_tmp4} + 0], {utility_log2(k_pack)}, v[{v_tmp4} + 0] ; shift right k_pack:{k_pack}") + self._emit(f"v_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 1], v[{v_gemm_in}] ; or k_pack:{k_pack}") + self._emit(f"v_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], v[{v_gemm_im}] ; or k_pack:{k_pack}") + self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n * k_pack)}, v[{v_gemm_in}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m * k_pack)}, v[{v_gemm_im}]") + else: + self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n)}, 
v[{v_gemm_in}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m)}, v[{v_gemm_im}]") self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_k())}, v[{v_thread_id}]") pass if ctrl.block_n_per_wave() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 0], {ctrl.block_n_per_wave() - 1}, v[{v_thread_id}] ; block_n_per_wave index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.block_n())}, v[{v_gemm_in}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.block_n() * k_pack)}, v[{v_gemm_in}]") self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_n_per_wave())}, v[{v_thread_id}]") if ctrl.block_m_per_wave() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 1], {ctrl.block_m_per_wave() - 1}, v[{v_thread_id}] ; block_m_per_wave index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], {utility_log2(ctrl.block_m())}, v[{v_gemm_im}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], {utility_log2(ctrl.block_m() * k_pack)}, v[{v_gemm_im}]") self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.block_m_per_wave())}, v[{v_thread_id}]") if ctrl.waves_per_n() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 2], {ctrl.waves_per_n() - 1}, v[{v_thread_id}] ; waves_per_n index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 2], {utility_log2(ctrl.wave_tile_n * ctrl.wave_step_n)}, v[{v_gemm_in}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 2], {utility_log2(ctrl.wave_tile_n * ctrl.wave_step_n * k_pack)}, v[{v_gemm_in}]") self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.waves_per_n())}, v[{v_thread_id}]") if ctrl.waves_per_m() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 3], {ctrl.waves_per_m() - 1}, v[{v_thread_id}] ; waves_per_m index") - self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 3], {utility_log2(ctrl.wave_tile_m * ctrl.wave_step_m)}, v[{v_gemm_im}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 3], {utility_log2(ctrl.wave_tile_m * ctrl.wave_step_m * k_pack)}, v[{v_gemm_im}]") # self._emit(f"v_lshrrev_b32 v[{v_thread_id}], {utility_log2(ctrl.waves_per_n())}, v[{v_thread_id}]") self._emit_empty_line() return self._get_deferred() diff --git a/igemm/codegen/macro.py b/igemm/codegen/macro.py index 81163c2c..ec057c9c 100644 --- a/igemm/codegen/macro.py +++ b/igemm/codegen/macro.py @@ -36,6 +36,7 @@ def __init__(self, mc, inline = False): mc_base_t.__init__(self, mc) self.arg_list = list() self.inline = inline + self.expr_cnt = 0 def name(self): return 'n/a macro' def is_inline(self): @@ -71,10 +72,13 @@ def __call__(self, *args): setattr(self, self.arg_list[i], sym_t(args[i])) elif type(args[i]) is sym_t: setattr(self, self.arg_list[i], args[i]) + elif args[i] is None: + setattr(self, self.arg_list[i], None) # 2nd, do the emit with self._deferred_context(): self.expr() + self.expr_cnt += 1 # last, restore arg to default value. 
for a in self.arg_list: @@ -90,3 +94,4 @@ def emit(self): if not self.is_inline(): with self._emit_macro_indented(".macro {} {}".format(self.name(), ' '.join(self.arg_list))): self.expr() + self.expr_cnt += 1 diff --git a/igemm/codegen/mbb.py b/igemm/codegen/mbb.py index a3f83473..2c778917 100644 --- a/igemm/codegen/mbb.py +++ b/igemm/codegen/mbb.py @@ -183,6 +183,25 @@ def is_mbb_start_cmp_and_exec_block(self, current_index, istrs_list): return True return False return False + + def is_mbb_start_bfe_and_cmpx_block(self, current_index, istrs_list): + assert type(istrs_list) is list + current_istr = istrs_list[current_index] + # current_mc_inst = create_mc_inst(current_istr) + current_inst_op = get_mc_inst_op(current_istr) + if current_inst_op.startswith('v_bfe_u32'): + #print('asdadds XXXXX') + for next_index in range(current_index+1, len(istrs_list)): + next_istr = istrs_list[next_index] + next_mc_inst = create_mc_inst(next_istr) + next_inst_op = get_mc_inst_op(next_istr) + #print(f' next_inst_op:{next_inst_op} ') + if not next_mc_inst: + continue + if next_inst_op.startswith('v_cmp'): + return True + return False + return False def is_mbb_start(self, istr): _istr = istr.strip() @@ -265,7 +284,8 @@ def match_group_mbb_by_end_of_inst_op(inst_op): mc_inst_buffer.append(mc_inst) else: if state == self.STATE_NORMAL: - if self.is_mbb_start(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs): + if self.is_mbb_start(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs) \ + or self.is_mbb_start_bfe_and_cmpx_block(i, istrs): mc_inst_buffer.clear() mc_inst_buffer.append(mc_inst) state = self.STATE_PARSING_MBB @@ -273,7 +293,12 @@ def match_group_mbb_by_end_of_inst_op(inst_op): mbbs.append(machine_basic_block_t(copy.copy([mc_inst]))) else: if self.is_mbb_start(istr): - assert False, f'not support recursive start/end for now, with {i}:{istr}, {istrs}' + assert i > 1 + if self.is_mbb_start_bfe_and_cmpx_block(i - 1, istrs): + # TODO: this require bfe and cmpx have no other lines in between + pass + else: + assert False, f'not support recursive start/end for now, with {i}:{istr}, {istrs}' if self.is_mbb_end(istr): mc_inst_buffer.append(mc_inst) mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer))) From d23bee55bdd80edc90ce21e172bf9f6841450ef9 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 28 Jan 2021 17:27:23 +0800 Subject: [PATCH 12/40] fix a bug in fwd validation --- config/igemm_fwd_gtc_gfx908_nhwc.config | 6 ++--- driver/conv_driver.cpp | 2 +- driver/igemm_fwd_gtc_driver.h | 36 ++++++++++++++----------- igemm/algo/igemm_base.py | 11 +++++--- igemm/algo/igemm_fwd_gtc_nhwc.py | 29 ++++---------------- 5 files changed, 38 insertions(+), 46 deletions(-) diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index 1aa2093d..1c71f606 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -21,7 +21,7 @@ tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' -nxb = 4 +nxb = 0 nxe = 0 @@ -44,7 +44,7 @@ tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' -nxb = 4 +nxb = 0 nxe = 0 #--------------------------- 256x128 @@ -65,5 +65,5 @@ tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' -nxb = 4 +nxb = 0 nxe = 1 diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index 93244774..f79570cc 100755 --- 
a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -511,7 +511,7 @@ int main(int argc, char **argv) { fflush(stdout); if (need_verify) - HIP_CALL(hipMemset(device_output, 0, n * c * ho * wo * sizeof(float))); + HIP_CALL(hipMemset(device_output, 0, n * k * ho * wo * sizeof(float))); result_t result = conv_fwd_driver.run(&conv_args, tunable, module, device_input, diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index fb5106af..602d6e04 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -146,7 +146,9 @@ class igemm_fwd_gtc_t { int gemm_n_per_block = tunable->gemm_n_per_block; int nxe = tunable->nxe; int nxb = tunable->nxb; - int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + int b = ho * wo; + if(tunable->tensor_layout == "nchw") + b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 int gemm_m = 0; int gemm_n = 0; @@ -196,7 +198,9 @@ class igemm_fwd_gtc_t { int nxe = tunable->nxe; int nxb = tunable->nxb; - int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + int b = ho * wo; + if(tunable->tensor_layout == "nchw") + b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 bool unit_conv = (x==1)&&(y==1)&&(stride_h==1)&&(stride_w==1)&&(dilation_h==1)&&(dilation_w==1)&&(pad_h==0)&&(pad_w==0); @@ -263,19 +267,19 @@ class igemm_fwd_gtc_t { return false; } - if(gemm_m_per_block % tunable->nxb != 0){ - //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); - return false; - } + // if(gemm_m_per_block % tunable->nxb != 0){ + // //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + // return false; + // } - if(n % (gemm_m_per_block / tunable->nxb) != 0){ - //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); - return false; - } + // if(n % (gemm_m_per_block / tunable->nxb) != 0){ + // //printf("tunable_is_valid false: n%(gemm_n_per_block/tunable->nxb)!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); + // return false; + // } - if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ - return false; - } + // if((nxe == 0) && ((b % tunable->nxb != 0) || (gemm_k % gemm_k_per_block != 0))){ + // return false; + // } if((nxe == 0) && !unit_conv){ return false; @@ -344,8 +348,10 @@ class igemm_fwd_gtc_t { int gemm_k_per_block = tunable->gemm_k_per_block; int nxe = tunable->nxe; int nxb = tunable->nxb; - int b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - + int b = ho * wo; + if(tunable->tensor_layout == "nchw") + b = nxe == 0 ? 
(ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 + igemm_fwd_gtc_karg_t karg; size_t karg_size = sizeof(karg); karg.p_in = p_in; diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index f2f3d6cc..0e3e166f 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -201,7 +201,12 @@ def __init__(self, tunable_dict): # assert type(self.opt_1x1) is bool assert self.direction in ('fwd', 'bwd', 'wrw') assert self.precision in ('fp32', 'fp16', 'bf16') - assert self.nxb in (1,4,8,16,32,64,128,256) + if self.tensor_layout == "nchw": + assert self.nxb in (1,4,8,16,32,64,128,256) + elif self.tensor_layout == "nhwc": + assert self.nxb == 0, 'nhwc now no need have different nxb value' + else: + assert False assert self.nxe in (0,1) # TODO: better specify @@ -226,13 +231,13 @@ def _unmerge_x1_from_e(unroll_k, nxe): return unroll_k # not used if self.direction == 'fwd': - assert self.gemm_n_per_block % self.nxb == 0 if self.tensor_layout == 'nchw': + assert self.gemm_n_per_block % self.nxb == 0 self.unmerge_sub_n = self.gemm_n_per_block // self.nxb self.unmerge_sub_k = 1 # not used self.unmerge_sub_c = _unmerge_x1_from_e(self.gemm_k_per_block, self.nxe) elif self.tensor_layout == 'nhwc': - self.unmerge_sub_n = self.gemm_m_per_block // self.nxb + self.unmerge_sub_n = 1 # not used self.unmerge_sub_k = 1 # not used self.unmerge_sub_c = 1 # not used else: diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index a2f23998..db7d953f 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -468,19 +468,10 @@ def __init__(self, mc, outer): self.s_block_gtc_ik = sym_t("s_block_gtc_ik" , sseq(1)) self.s_block_gtc_inb = sym_t("s_block_gtc_inb" , sseq(1)) - self.s_move_slice_k_c1e = sym_t("s_move_slice_k_c1e" , sseq(1)) - if outer.tunable.nxe != 0: - self.s_move_slice_k_c = sym_t("s_move_slice_k_c" , sseq(1)) - self.s_move_slice_k_y = sym_t("s_move_slice_k_y" , sseq(1)) - self.s_move_slice_k_x = sym_t("s_move_slice_k_x" , self.s_block_gtc_ig.value) - self.s_move_slice_k_stride_c = sym_t("s_move_slice_k_stride_c" , sseq(1)) self.s_knum = sym_t("s_knum" , 3) self.s_gemm_k_num_c = sym_t("s_gemm_k_num_c" , sseq(1)) - if outer.tunable.nxe != 0: - self.s_gemm_k_num_y = sym_t("s_gemm_k_num_y" , self.s_y.value) - self.s_gemm_k_num_x = sym_t("s_gemm_k_num_x" , self.s_x.value) #if outer.tunable.nxe != 0: self.s_dim_br = sym_t("s_dim_br" , sseq(1)) @@ -488,14 +479,6 @@ def __init__(self, mc, outer): self.s_dim_mr = sym_t("s_dim_mr" , sseq(1)) self.s_dim_np = sym_t("s_dim_np" , sseq(1)) - # self.s_len_h = sym_t("s_len_h" , sseq(1)) - # self.s_len_w = sym_t("s_len_w" , sseq(1)) - # self.s_lim_h = sym_t("s_lim_h" , sseq(1)) - # self.s_lim_w = sym_t("s_lim_w" , sseq(1)) - # self.s_thread_stride_w = sym_t("s_thread_stride_w" , sseq(1)) - # self.s_thread_stride_h = sym_t("s_thread_stride_h" , sseq(1)) - # self.s_thread_stride_n = sym_t("s_thread_stride_n" , sseq(1)) - self.s_in_diff_hi = sym_t("s_in_diff_hi" , sseq(1)) self.s_in_diff_wi = sym_t("s_in_diff_wi" , sseq(1)) self.s_dilation_w_x = sym_t("s_dilation_w_x" , sseq(1)) @@ -1420,16 +1403,14 @@ def emit_kernel_prologue(self): self._emit(f"v_lshlrev_b32 v[{v.v_tmp()}], {igemm_log2(data_byte)}, v[{v.v_co_sub_n_index()}]") self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp()}]") - - self._emit(f"; add i_k") self._emit(f"; move slice stride") self._emit(f"s_lshl_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_c()}], {igemm_log2(data_byte)}") - if 
self.tunable.nxe != 0: - self._emit(f"s_mov_b32 s[{s.s_tmp()}], {na_c}") - self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_tmp()}], {igemm_log2(data_byte)}") - else: - self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_c()}], {na_c * data_byte}") + # if self.tunable.nxe != 0: + # self._emit(f"s_mov_b32 s[{s.s_tmp()}], {na_c}") + # self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_tmp()}], {igemm_log2(data_byte)}") + # else: + self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_c()}], {na_c * data_byte}") if self.tunable.nxe != 0: # s_in_diff_wi : s_dilation_w * s_in_stride_wi From 2647213865f5d0a54d24aa886560f52b7dec7e25 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 28 Jan 2021 18:53:42 +0800 Subject: [PATCH 13/40] fix a bug in non 1x1 case --- config/igemm_fwd_gtc_gfx908_nhwc.config | 67 ++++++++++++++++ igemm/algo/igemm_fwd_gtc_nhwc.py | 100 +----------------------- 2 files changed, 71 insertions(+), 96 deletions(-) diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index 1c71f606..bc851106 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -67,3 +67,70 @@ precision = "fp32" tensor_layout = 'nhwc' nxb = 0 nxe = 1 + + + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 8 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 2, 1, 128] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 2, 1, 128] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 128x128 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 128x128 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 \ No newline at end of file diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index db7d953f..e6e4050a 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -207,7 +207,7 @@ def expr(self): self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_offset()}]") - self._emit(f"s_cselect_b32 s[{self.s_flag_need_acc_yx()}], 0, 1") + self._emit(f"s_cselect_b32 
s[{self.s_flag_need_acc_yx()}], 1, 0") self._emit_empty_line() class macro_move_slice_window_block_wise_acc_yx_t(macro_base_t): @@ -268,12 +268,12 @@ def expr(self): self._emit(f"s_cmp_le_u32 s[{self.s_x()}], s[{self.s_move_slice_k_ix()}]") # update iwi - self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_dilation_w()}], s[{self.s_dilation_w_x()}]") + self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_dilation_w_x()}], s[{self.s_dilation_w()}]") for i in range(nb_per_thread): self._emit(f"v_add_u32 v[{self.v_in_iwi_list(i)}], s[{self.s_tmp()}], v[{self.v_in_iwi_list(i)}]") # update in_os - self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_in_diff_wi()}], s[{self.s_in_diff_hi()}]") + self._emit(f"s_cselect_b32 s[{self.s_tmp()}], s[{self.s_in_diff_hi()}], s[{self.s_in_diff_wi()}]") for i in range(nb_per_thread): self._emit(f"v_add_u32 v[{self.v_in_os(i)}], s[{self.s_tmp()}], v[{self.v_in_os(i)}]") @@ -1157,79 +1157,8 @@ def emit_kernel_prologue(self): self._emit(m_set_flag_nhw(v.v_in_flag(0), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) self._emit_empty_line() - # if self.tunable.nxe != 0: - # self._emit(f"s_mul_i32 s[{s.s_len_h()}], s[{s.s_ho()}], s[{s.s_stride_h()}]") - # self._emit(f"s_sub_i32 s[{s.s_lim_h()}], s[{s.s_len_h()}], s[{s.s_pad_h()}]") - # self._emit(f"s_mul_i32 s[{s.s_len_w()}], s[{s.s_wo()}], s[{s.s_stride_w()}]") - # self._emit(f"s_sub_i32 s[{s.s_lim_w()}], s[{s.s_len_w()}], s[{s.s_pad_w()}]") - # voffset if ta_nb0 != 1 or ta_nb1 != 1: - ''' - thread_stride = na_nb1 if ta_nb0 != 1 else 1 - self._emit(f"s_mov_b32 s[{s.s_tmp(5)}], {thread_stride}") - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_thread_stride_n(), s.s_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), s.s_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_ss(s.s_thread_stride_w(), s.s_thread_stride_h(), s.s_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), s.s_tmp())) - else: - self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_thread_stride_n(), s.s_tmp(5), s.s_dim_br(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) - self._emit(m_int_div_rem_ss(s.s_thread_stride_w(), s.s_thread_stride_h(), s.s_tmp(4), s.s_wo() if self.tunable.nxe != 0 else s.s_wi(), v.v_tmp(5), v.v_tmp(), s.s_tmp())) - - if self.tunable.nxe != 0: - self._emit(f"s_mul_i32 s[{s.s_thread_stride_h()}], s[{s.s_thread_stride_h()}], s[{s.s_stride_h()}]") - self._emit(f"s_mul_i32 s[{s.s_thread_stride_w()}], s[{s.s_thread_stride_w()}], s[{s.s_stride_w()}]") - - # now let's precompute all the voffset - # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h - # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w - self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v[{v.v_in_ihi_list()}]") - self._emit(f"v_mov_b32 v[{v.v_tmp(4)}], v[{v.v_in_iwi_list()}]") - self._emit(f"v_mov_b32 v[{v.v_tmp(3)}], v[{v.v_in_in()}]") - nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 - for i in range(1, nb_per_thread): - # v_tmp+4:ihi, v_tmp+5:iwi - self._emit(f"v_add_i32 v[{v.v_tmp(4)}], s[{s.s_thread_stride_w()}], v[{v.v_tmp(4) }]") - self._emit(f"v_cmpx_le_i32 vcc, s[{s.s_lim_w()}], v[{v.v_tmp(4)}]") - self._emit(f"v_sub_i32 v[{v.v_tmp(4)}], v[{v.v_tmp(4)}], s[{s.s_len_w()}]") - if self.tunable.nxe != 0: - self._emit(f"v_add_i32 v[{v.v_tmp(5)}], s[{s.s_stride_h()}], v[{v.v_tmp(5)}]") - else: - self._emit(f"v_add_i32 v[{v.v_tmp(5)}], 1, 
v[{v.v_tmp(5)}]") - self._emit(f"s_mov_b64 exec, -1") - - self._emit(f"v_add_i32 v[{v.v_tmp(5)}], s[{s.s_thread_stride_h()}], v[{v.v_tmp(5)}]") - self._emit(f"v_cmpx_le_i32 vcc, s[{s.s_lim_h()}], v[{v.v_tmp(5)}]") - self._emit(f"v_sub_i32 v[{v.v_tmp(5)}], v[{v.v_tmp(5)}], s[{s.s_len_h()}]") - self._emit(f"v_add_u32 v[{v.v_tmp(3)}], 1, v[{v.v_tmp(3)}]") - self._emit(f"s_mov_b64 exec, -1") - - self._emit(f"v_add_u32 v[{v.v_tmp(3)}], s[{s.s_thread_stride_n()}], v[{v.v_tmp(3)}]") - - self._emit(f"v_mov_b32 v[{v.v_in_ihi_list(i)}], v[{v.v_tmp(5)}]") - self._emit(f"v_mov_b32 v[{v.v_in_iwi_list(i)}], v[{v.v_tmp(4)}]") - - if self.tunable.nxe != 0: - # update flag for batch size - self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_tmp(3)}]") - self._emit(f"v_cndmask_b32 v[{v.v_tmp(1)}], 0, 1, vcc") - # extra, store this into flag n - if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: - self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp(1)}], {16 + i}, v[{v.v_in_flag()}]") - self._emit(m_set_flag_nhw(v.v_tmp(), v.v_tmp(1), v.v_tmp(5), v.v_tmp(4), s.s_hi(), s.s_wi())) - self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], {i}, v[{v.v_in_flag()}]") - else: - self._emit(f"v_lshl_or_b32 v[{v.v_in_flag_n()}], v[{v.v_tmp(1)}], {i}, v[{v.v_in_flag_n()}]") - self._emit(m_set_flag_nhw(v.v_in_flag(i), v.v_tmp(1), v.v_tmp(5), v.v_tmp(4), s.s_hi(), s.s_wi())) - - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_tmp(3)}]") - self._emit(f"v_add_lshl_u32 v[{v.v_tmp(2)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_tmp(5)}]") - self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") - self._emit(f"v_add_u32 v[{v.v_in_os(i)}], v[{v.v_tmp(2)}], v[{v.v_tmp()}]") - ''' thread_stride = na_nb1 if ta_nb0 != 1 else 1 nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 @@ -1380,12 +1309,6 @@ def emit_kernel_prologue(self): self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") - # self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_in0()}], {igemm_log2(unmerge_sub_n1 * data_byte)}") - # self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") - # self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_out_stride_n()}], s[{s.s_tmp(3)}]") - # self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") - # self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") - self._emit_empty_line() self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], s[{s.s_block_gtc_ik()}], {igemm_log2(data_byte)}") self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp(3)}]") @@ -1393,11 +1316,6 @@ def emit_kernel_prologue(self): self._emit_empty_line() self._emit(self.try_shift_stride(s.s_out_stride_wo, igemm_log2(data_byte))) - # self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_out_stride_wo()}], s[{s.s_block_gtc_inb()}]") - # self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_out_stride_wo()}], s[{s.s_block_gtc_inb()}]") - # self._emit(f"s_add_u32 s[{s.s_p_out(0)}], s[{s.s_p_out(0)}], s[{s.s_tmp(0)}]") - # self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") - # self._emit_empty_line() self._emit(f"v_add_u32 v[{v.v_out_inb()}], s[{s.s_block_gtc_inb()}], v[{v.v_co_sub_m_index()}] ; total n*ho*wo") self._emit(f"v_mul_lo_u32 v[{v.v_out_os()}], 
s[{s.s_out_stride_wo()}], v[{v.v_out_inb()}]") self._emit(f"v_lshlrev_b32 v[{v.v_tmp()}], {igemm_log2(data_byte)}, v[{v.v_co_sub_n_index()}]") @@ -1418,7 +1336,7 @@ def emit_kernel_prologue(self): # s_dilation_w_x : -1* (x - 1) * s_dilation_w self._emit(f"s_mov_b32 s[{s.s_move_slice_k_ix()}], 0") self._emit(f"s_mul_i32 s[{s.s_in_diff_wi()}], s[{s.s_dilation_w()}], s[{s.s_in_stride_wi()}]") # shifted - self._emit(f"s_mul_i32 s[{s.s_tmp(3)}], s[{s.s_x()}], 1") + self._emit(f"s_sub_i32 s[{s.s_tmp(3)}], s[{s.s_x()}], 1") self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_in_diff_wi()}], s[{s.s_tmp(3)}]") self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_in_stride_wi()}], s[{s.s_wi()}]") self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_tmp(1)}], s[{s.s_dilation_h()}]") @@ -1428,16 +1346,6 @@ def emit_kernel_prologue(self): self._emit_empty_line() - # if self.tunable.nxe != 0: - # # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) - # self._emit(self.try_shift_stride(s.s_wei_stride_k, igemm_log2(data_byte))) - # # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) - # else: - # # self._emit(self.try_shift_stride(s.s_stride_c, igemm_log2(data_byte))) - # self._emit(self.try_shift_stride(s.s_c, igemm_log2(data_byte))) - # # self._emit(self.try_shift_stride(s.s_out_stride_k, igemm_log2(data_byte))) - - # self._emit(self.try_shift_stride(s.s_move_slice_k_c1e, igemm_log2(data_byte))) self._emit(f"s_mov_b32 s[{s.s_p_out(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_out(3)}], 0x27000") From 801261f6ce9f4db161b15983441f2ab1bb5264f0 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 28 Jan 2021 20:13:12 +0800 Subject: [PATCH 14/40] add magic div --- config/igemm_fwd_gtc_gfx908_nhwc.config | 24 +++- driver/igemm_fwd_gtc_driver.h | 169 +++++++++++++++++------- igemm/algo/igemm_fwd_gtc_nhwc.py | 102 ++++++-------- 3 files changed, 182 insertions(+), 113 deletions(-) diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index bc851106..e24c3ea6 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -133,4 +133,26 @@ direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' nxb = 0 -nxe = 1 \ No newline at end of file +nxe = 1 + +# #--------------------------- 128x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 128 +# gemm_k_per_block = 32 +# wave_tile_m = 32 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 \ No newline at end of file diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 602d6e04..d290824b 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -71,6 +71,38 @@ typedef struct { #endif } __attribute__((packed)) igemm_fwd_gtc_karg_t; +typedef struct { + float *p_in; + float *p_wei; + float *p_out; + int hi; + int wi; + int n; + int k; // this is indeed k_per_group + int c; // this is indeed c_per_group + int ho; + int wo; + int stride_h; + int stride_w; + int dilation_h; + int dilation_w; + int pad_h; + int pad_w; + int y; + int x; + int group; +#if USE_MAGIC_DIV + uint32_t magic_0; // denom: gemm_n / 
n_per_block + uint32_t magic_1; // denom: ho*wo + uint32_t magic_2; // denom: wo + uint32_t magic_3; // denom: (gemm_m/m_per_block) * (gemm_n/n_per_block) + uint32_t shift_pack_0; + uint32_t __pack_0; +#endif +} __attribute__((packed)) igemm_fwd_gtc_nhwc_karg_t; + +#define IGEMM_FWD_GTC_MAX_KARG_SIZE 160 + static void dump_fwd_karg(igemm_fwd_gtc_karg_t * karg){ std::cout<<"p_in:" <p_in<<","; std::cout<<"p_wei:" <p_wei<<","; @@ -352,62 +384,105 @@ class igemm_fwd_gtc_t { if(tunable->tensor_layout == "nchw") b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 - igemm_fwd_gtc_karg_t karg; - size_t karg_size = sizeof(karg); - karg.p_in = p_in; - karg.p_wei = p_wei; - karg.p_out = p_out; - karg.hi = hi; - karg.wi = wi; - karg.n = n; - karg.k = k / group; - karg.c = c / group; - karg.ho = ho; - karg.wo = wo; - - karg.stride_h = stride_h; - karg.stride_w = stride_w; - karg.dilation_h = dilation_h; - karg.dilation_w = dilation_w; - karg.pad_h = pad_h; - karg.pad_w = pad_w; - karg.y = y; - karg.x = x; - karg.group = group; + size_t karg_size = 0; + uint8_t karg_buffer[IGEMM_FWD_GTC_MAX_KARG_SIZE]; + if(tunable->tensor_layout == "nchw"){ + igemm_fwd_gtc_karg_t karg; + karg.p_in = p_in; + karg.p_wei = p_wei; + karg.p_out = p_out; + karg.hi = hi; + karg.wi = wi; + karg.n = n; + karg.k = k / group; + karg.c = c / group; + karg.ho = ho; + karg.wo = wo; + karg.stride_h = stride_h; + karg.stride_w = stride_w; + karg.dilation_h = dilation_h; + karg.dilation_w = dilation_w; + karg.pad_h = pad_h; + karg.pad_w = pad_w; + karg.y = y; + karg.x = x; + karg.group = group; #if USE_MAGIC_DIV - int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; - int gemm_n = n * b; - { - // init magic division parameters - uint32_t nb_n0 = tunable->tensor_b_cluster_lengths[2] * tunable->tensor_b_thread_lengths[2]; - uint32_t nb_n1b = tunable->tensor_b_cluster_lengths[3] * tunable->tensor_b_thread_lengths[3]; - uint32_t unmerge_sub_n = gemm_n_per_block / nxb; - uint32_t unmerge_sub_n1 = tunable->gemm_n_unmerge_cluster == 0 ? unmerge_sub_n / nb_n0 : unmerge_sub_n; - - magic_div_u32_t mdiv_0 = magic_div_u32_gen(tunable->source_access_order == 0 ? ((n * b) / gemm_n_per_block) : ((gemm_m) / gemm_m_per_block)); - magic_div_u32_t mdiv_1 = magic_div_u32_gen(tunable->gemm_n_unmerge_cluster == 0 ? - b * unmerge_sub_n1 / nb_n1b : - (n / nb_n0) * b / nb_n1b ); - magic_div_u32_t mdiv_2 = magic_div_u32_gen(y * x); - magic_div_u32_t mdiv_3 = magic_div_u32_gen(x); - magic_div_u32_t mdiv_4 = magic_div_u32_gen(b); - magic_div_u32_t mdiv_5 = magic_div_u32_gen(wo); - magic_div_u32_t mdiv_6 = magic_div_u32_gen(utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * - utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); + int gemm_m = ((k/group + gemm_m_per_block -1)/gemm_m_per_block) * gemm_m_per_block; + int gemm_n = n * b; + { + // init magic division parameters + uint32_t nb_n0 = tunable->tensor_b_cluster_lengths[2] * tunable->tensor_b_thread_lengths[2]; + uint32_t nb_n1b = tunable->tensor_b_cluster_lengths[3] * tunable->tensor_b_thread_lengths[3]; + uint32_t unmerge_sub_n = gemm_n_per_block / nxb; + uint32_t unmerge_sub_n1 = tunable->gemm_n_unmerge_cluster == 0 ? unmerge_sub_n / nb_n0 : unmerge_sub_n; + + magic_div_u32_t mdiv_0 = magic_div_u32_gen(tunable->source_access_order == 0 ? ((n * b) / gemm_n_per_block) : ((gemm_m) / gemm_m_per_block)); + magic_div_u32_t mdiv_1 = magic_div_u32_gen(tunable->gemm_n_unmerge_cluster == 0 ? 
+ b * unmerge_sub_n1 / nb_n1b : + (n / nb_n0) * b / nb_n1b ); + magic_div_u32_t mdiv_2 = magic_div_u32_gen(y * x); + magic_div_u32_t mdiv_3 = magic_div_u32_gen(x); + magic_div_u32_t mdiv_4 = magic_div_u32_gen(b); + magic_div_u32_t mdiv_5 = magic_div_u32_gen(wo); + magic_div_u32_t mdiv_6 = magic_div_u32_gen(utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * + utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); + + karg.magic_0 = mdiv_0.magic; + karg.magic_1 = mdiv_1.magic; + karg.magic_2 = mdiv_2.magic; + karg.magic_3 = mdiv_3.magic; + karg.magic_4 = mdiv_4.magic; + karg.magic_5 = mdiv_5.magic; + karg.magic_6 = mdiv_6.magic; + karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); + karg.shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); + } +#endif + karg_size = sizeof(karg); + memcpy(static_cast(&karg_buffer[0]), static_cast(&karg), karg_size); + }else if(tunable->tensor_layout == "nhwc"){ + igemm_fwd_gtc_nhwc_karg_t karg; + karg.p_in = p_in; + karg.p_wei = p_wei; + karg.p_out = p_out; + karg.hi = hi; + karg.wi = wi; + karg.n = n; + karg.k = k / group; + karg.c = c / group; + karg.ho = ho; + karg.wo = wo; + karg.stride_h = stride_h; + karg.stride_w = stride_w; + karg.dilation_h = dilation_h; + karg.dilation_w = dilation_w; + karg.pad_h = pad_h; + karg.pad_w = pad_w; + karg.y = y; + karg.x = x; + karg.group = group; +#if USE_MAGIC_DIV + int gemm_m = n * ho * wo; + int gemm_n = k / group; + magic_div_u32_t mdiv_0 = magic_div_u32_gen(gemm_n / gemm_n_per_block); + magic_div_u32_t mdiv_1 = magic_div_u32_gen(ho*wo); + magic_div_u32_t mdiv_2 = magic_div_u32_gen(wo); + magic_div_u32_t mdiv_3 = magic_div_u32_gen((gemm_m/gemm_m_per_block) * (gemm_n/gemm_n_per_block)); karg.magic_0 = mdiv_0.magic; karg.magic_1 = mdiv_1.magic; karg.magic_2 = mdiv_2.magic; karg.magic_3 = mdiv_3.magic; - karg.magic_4 = mdiv_4.magic; - karg.magic_5 = mdiv_5.magic; - karg.magic_6 = mdiv_6.magic; karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); - karg.shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); - } #endif + karg_size = sizeof(karg); + memcpy(static_cast(&karg_buffer[0]), static_cast(&karg), karg_size); + } else { + assert(0); + } int block_size = get_block_size(tunable); int grid_size = get_grid_size(arg, tunable); @@ -421,7 +496,7 @@ class igemm_fwd_gtc_t { auto launch_fwd = [&]() -> float { // printf("launch fwd block:%d, grid:%d\n", block_size, grid_size); // dump_fwd_karg(&karg); - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, + void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, static_cast(&karg_buffer[0]), HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, HIP_LAUNCH_PARAM_END}; float ms = .0; diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index e6e4050a..8c0b2184 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -407,13 +407,9 @@ def __init__(self, mc, outer): self.k_magic_1 = sym_t('k_magic_1' ,92) self.k_magic_2 = sym_t('k_magic_2' ,96) self.k_magic_3 = sym_t('k_magic_3' ,100) - self.k_magic_4 = sym_t('k_magic_4' ,104) - self.k_magic_5 = sym_t('k_magic_5' ,108) - self.k_magic_6 = sym_t('k_magic_6' ,112) - self.k_shift_pack_0 = sym_t('k_shift_pack_0' ,116) - self.k_shift_pack_1 = sym_t('k_shift_pack_1' ,120) - self.k__pack_0 = sym_t('k__pack_0' ,124) - self.k_end = sym_t('k_end' ,128) + self.k_shift_pack_0 = sym_t('k_shift_pack_0' ,104) + self.k__pack_0 = 
sym_t('k__pack_0' ,108) + self.k_end = sym_t('k_end' ,112) else: self.k_end = sym_t('k_end' ,88) @@ -491,28 +487,18 @@ def __init__(self, mc, outer): m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() #in_npc = m_in_2d_global_load.get_num_precache_soffset() wei_npc = m_wei_2d_global_load.get_num_precache_soffset() - #self.s_in_offset = sym_t("s_in_offset" ,sseq(in_npc)) # if this number is zero, it is also OK, since we would not use self.s_wei_offset = sym_t("s_wei_offset" ,sseq(wei_npc)) - # self.s_k_padded = sym_t("s_k_padded" ,sseq(1)) # TODO: this sgpr allocation is a mess if IGEMM_GTC_FEAT_MAGIC_DIVISION: # allocate several sgpr to hold magic/shift value. - self.s_shift_pack_0 = sym_t("s_shift_pack_0" ,self.s_p_out.value + 2) - self.s_shift_pack_1 = sym_t("s_shift_pack_1" ,self.s_p_out.value + 3) - - self.s_magic_2 = sym_t("s_magic_2" ,self.s_in_stride_c_c1.value) # when load, loadx4 with magic_0/1 - self.s_magic_3 = sym_t("s_magic_3" ,self.s_in_stride_c_c0_c1_diff.value) # when load, loadx4 with magic_0/1 - - self.s_magic_4 = sym_t("s_magic_4" ,self.s_move_slice_k_c1e.value) - self.s_magic_5 = sym_t("s_magic_5" ,self.s_gemm_k_num_c1.value) - self.s_magic_6 = sym_t("s_magic_6" ,self.s_block_gtc_in0.value) + self.s_magic_0 = sym_t("s_magic_0" ,self.s_p_in.value + 2) + self.s_magic_1 = sym_t("s_magic_1" ,self.s_p_in.value + 3) + self.s_magic_2 = sym_t("s_magic_2" ,self.s_p_wei.value + 2) + self.s_magic_3 = sym_t("s_magic_3" ,self.s_p_wei.value + 3) + self.s_shift_pack_0 = sym_t("s_shift_pack_0" ,self.s_flag_need_acc_yx.value) self.s_tmp = sym_t("s_tmp" ,sseq(6, 2)) - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self.s_magic_0 = sym_t("s_magic_0" ,self.s_p_wei.value + 2) - self.s_magic_1 = sym_t("s_magic_1" ,self.s_p_wei.value + 3) - self.s_end = sym_t("s_end" ,sseq()) def get_count(self): @@ -574,8 +560,6 @@ def __init__(self, mc, outer): self.v_in_inb = sym_t("v_in_inb" ,vseq(1)) self.v_in_in = sym_t("v_in_in" ,vseq(1)) - # if tb_k0 != 1: - # self.v_wei_ik0 = sym_t("v_wei_ik0" ,vseq(1)) self.v_wei_ik = sym_t("v_wei_ik" ,vseq(1)) self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) @@ -585,13 +569,13 @@ def __init__(self, mc, outer): self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) self.v_out_inb = sym_t("v_out_inb" ,vseq(1)) - self.v_gemm_in = sym_t("v_gemm_in" , vseq(1)) - self.v_gemm_im = sym_t("v_gemm_im" , vseq(1)) + self.v_gemm_in = sym_t("v_gemm_in" ,vseq(1)) + self.v_gemm_im = sym_t("v_gemm_im" ,vseq(1)) - self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,vseq(1)) - self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,vseq(1)) + self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,vseq(1)) + self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,vseq(1)) - self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) + self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) total_vgpr = vseq() if outer.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # if xdlops agpr is larger than vgpr usage, must change vgpr count to agpr @@ -895,12 +879,8 @@ def get_kernel_args(self): kas.append(amdgpu_kernel_arg_t('magic_1' , 4, 92, 'by_value','i32')) kas.append(amdgpu_kernel_arg_t('magic_2' , 4, 96, 'by_value','i32')) kas.append(amdgpu_kernel_arg_t('magic_3' , 4, 100, 'by_value','i32')) - kas.append(amdgpu_kernel_arg_t('magic_4' , 4, 104, 'by_value','i32')) - kas.append(amdgpu_kernel_arg_t('magic_5' , 4, 108, 'by_value','i32')) - kas.append(amdgpu_kernel_arg_t('magic_6' , 4, 112, 'by_value','i32')) - kas.append(amdgpu_kernel_arg_t('shift_pack_0' , 4, 116, 'by_value','i32')) - kas.append(amdgpu_kernel_arg_t('shift_pack_1' , 4, 120, 
'by_value','i32')) - kas.append(amdgpu_kernel_arg_t('__pack_0' , 4, 124, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('shift_pack_0' , 4, 104, 'by_value','i32')) + kas.append(amdgpu_kernel_arg_t('__pack_0' , 4, 108, 'by_value','i32')) else: pass return kas @@ -983,10 +963,8 @@ def emit_kernel_prologue(self): if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_load_dwordx2 s[{s.s_magic_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_0()}") - self._emit(f"s_load_dwordx2 s[{s.s_tmp((2, 3))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_2()}") - self._emit(f"s_load_dwordx2 s[{s.s_tmp((4, 5))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_4()}") - self._emit(f"s_load_dword s[{s.s_magic_6()}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_6()}") - self._emit(f"s_load_dwordx2 s[{s.s_shift_pack_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") + self._emit(f"s_load_dwordx2 s[{s.s_magic_2((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_2()}") + self._emit(f"s_load_dword s[{s.s_shift_pack_0()}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") @@ -999,18 +977,10 @@ def emit_kernel_prologue(self): self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k1, tb_k1, True)) self._emit_empty_line() - self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") - self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") self._emit(f"s_waitcnt lgkmcnt(0)") self._emit_empty_line() - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_mov_b32 s[{s.s_magic_2()}], s[{s.s_tmp(2)}]") - self._emit(f"s_mov_b32 s[{s.s_magic_3()}], s[{s.s_tmp(3)}]") - self._emit(f"s_mov_b32 s[{s.s_magic_4()}], s[{s.s_tmp(4)}]") - self._emit(f"s_mov_b32 s[{s.s_magic_5()}], s[{s.s_tmp(5)}]") self._emit(f"; calculate index") - # calculate stride, not shift data byte yet # input self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") @@ -1060,8 +1030,8 @@ def emit_kernel_prologue(self): self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_dim_np()}], {igemm_log2(self.tunable.gemm_n_per_block)}") self._emit(f"s_mul_i32 s[0], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080010 ; offset:16, width:8") - self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), s.s_magic_6(), s.s_tmp(3), '0', s.s_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080018 ; offset:24, width:8") + self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), s.s_magic_3(), s.s_tmp(3), '0', s.s_tmp())) else: self._emit(m_int_div_rem_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), '0', v.v_tmp(5), v.v_tmp(), s.s_tmp())) @@ -1092,10 +1062,10 @@ def emit_kernel_prologue(self): self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_inb()}], v[{v.v_in_inb()}]") if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + 
self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wo(), v.v_tmp())) else: self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) @@ -1110,10 +1080,10 @@ def emit_kernel_prologue(self): else: if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wi(), v.v_tmp())) else: self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) @@ -1168,10 +1138,10 @@ def emit_kernel_prologue(self): self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_inb()}], v[{v.v_tmp()}]") if self.tunable.nxe != 0: if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wo(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wo(), v.v_tmp())) else: self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) @@ -1186,10 +1156,10 @@ def emit_kernel_prologue(self): else: if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_4(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), 
v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_5(), s.s_tmp(3), s.s_wi(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wi(), v.v_tmp())) else: self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) @@ -1217,6 +1187,8 @@ def emit_kernel_prologue(self): pass # load in + self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") + self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") self._emit(self.global_load_in()) self._emit_empty_line() self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") From 8afe37a2b29428e01644ad0ad5eddf134b5a0a09 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 29 Jan 2021 16:06:59 +0800 Subject: [PATCH 15/40] fwd support split batch (#78) * fwd support split batch * remove confusing assert * fix several >4G address type --- driver/conv_driver.cpp | 30 +++++++++++----- driver/igemm_fwd_gtc_driver.h | 67 +++++++++++++++++++++++++++++++++-- igemm/algo/igemm_fwd_gtc.py | 20 ++++++++++- 3 files changed, 104 insertions(+), 13 deletions(-) diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index f79570cc..b9aa9815 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -165,6 +165,8 @@ measured_fp32_conv_gflops(double time_ms, size_t n, size_t c, size_t hi, #define IGEMM_CONFIG_FILE "igemm_gtc.config" #endif +#define IGEMM_RUN_ONLY_KERNEL_DEFAULT "off" + #define WARMUP 3 #define REPEAT 8 #define SCLK_MHZ 1283 @@ -214,14 +216,14 @@ struct distribution_t{ }; template -void block_wise_rand_generator(Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T max, Src_T scale) +void block_wise_rand_generator(Dst_T *p, int tid, int block_size, size_t total_size, Src_T min, Src_T max, Src_T scale) { std::mt19937 rng(std::chrono::system_clock::now() .time_since_epoch() .count() + std::hash()(std::this_thread::get_id())); distribution_t distribution(min,max); - for (int i = tid; i < total_size; i += block_size) { + for (size_t i = tid; i < total_size; i += block_size) { p[i] = static_cast(scale * distribution(rng)); } } @@ -342,6 +344,7 @@ void dump_arg(const args_t *arg) { int main(int argc, char **argv) { char *hsaco = env_get_str("IGEMM_HSACO", IGEMM_HSACO); char *config_file = env_get_str("IGEMM_CONFIG_FILE", IGEMM_CONFIG_FILE); + std::string run_only_kernel = env_get_str("IGEMM_RUN_ONLY_KERNEL", IGEMM_RUN_ONLY_KERNEL_DEFAULT); int warmup = env_get_int("IGEMM_WARMUP", WARMUP); int repeat = env_get_int("IGEMM_REPEAT", REPEAT); int sclk_mhz = env_get_int("IGEMM_SCLK_MHZ", SCLK_MHZ); @@ -457,8 +460,8 @@ int main(int argc, char **argv) { gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 0.0, 1.0); gen_rand_vector(host_weight, static_cast(k) * c * y * x, -0.5, 0.5); - //gen_rand_vector(host_input, n * c * hi * wi, 1, 1); - //gen_rand_vector(host_weight, k * c * y * x, 1, 1); + //gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 1, 1); + //gen_rand_vector(host_weight, static_cast(k) * c * y * x, 1, 1); #ifdef USE_GPU_NAIVE_CONV HIP_CALL(hipMemcpy(device_input, host_input, @@ -506,6 +509,9 @@ int 
main(int argc, char **argv) { double nrms = get_fwd_nrms(); for (int i = 0; i < tunables.size(); i++) { igemm_gtc_tunable_t *tunable = &tunables[i]; + if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) + if(run_only_kernel != conv_fwd_driver.get_kernel_name(tunable)) + continue; printf("[fwd:%2d] %s, ", i, conv_fwd_driver.get_kernel_name(tunable).c_str()); fflush(stdout); @@ -569,8 +575,8 @@ int main(int argc, char **argv) { gen_rand_vector(host_output, static_cast(n) * k * ho * wo, 0.0, 1.0); gen_rand_vector(host_weight, static_cast(k) * c * y * x, -0.5, 0.5); gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 999999., 9999999.); // manually input value to a very large number - // gen_rand_vector(host_output, n * k * ho * wo,1, 1); - // gen_rand_vector(host_weight, k * c * y * x, 1, 1); + // gen_rand_vector(host_output, static_cast(n) * k * ho * wo,1, 1); + // gen_rand_vector(host_weight, static_cast(k) * c * y * x, 1, 1); #ifdef USE_GPU_NAIVE_CONV HIP_CALL(hipMemcpy(device_output, host_output, static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice)); @@ -618,6 +624,9 @@ int main(int argc, char **argv) { double nrms = get_bwd_nrms(); for (int i = 0; i < tunables.size(); i++) { igemm_gtc_tunable_t *tunable = &tunables[i]; + if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) + if(run_only_kernel != conv_bwd_driver.get_kernel_name(tunable)) + continue; printf("[bwd:%2d] %s, ", i, conv_bwd_driver.get_kernel_name(tunable).c_str()); fflush(stdout); @@ -680,8 +689,8 @@ int main(int argc, char **argv) { // gen rand gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 0.0, 1.0); gen_rand_vector(host_output, static_cast(n) * k * ho * wo, -0.5, 0.5); - //gen_rand_vector(host_input, n * k * hi * wi, -5, 5); - //gen_rand_vector(host_output, n * k * ho * wo, 1, 1); + //gen_rand_vector(host_input, static_cast(n) * k * hi * wi, -5, 5); + //gen_rand_vector(host_output, static_cast(n) * k * ho * wo, 1, 1); #ifdef USE_GPU_NAIVE_CONV HIP_CALL(hipMemcpy(device_input, host_input, static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); @@ -763,13 +772,16 @@ int main(int argc, char **argv) { for (int i = 0; i < tunables.size(); i++) { igemm_gtc_tunable_t *tunable = &tunables[i]; + if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) + if(run_only_kernel != conv_wrw_driver.get_kernel_name(tunable)) + continue; printf("[wrw:%2d] %s, ", i, conv_wrw_driver.get_kernel_name(tunable).c_str()); fflush(stdout); if (need_verify) HIP_CALL(hipMemset(device_weight, 0, - k * c * y * x * sizeof(float))); + static_cast(k) * c * y * x * sizeof(float))); result_t result = conv_wrw_driver.run(&conv_args, tunable, module, device_input, device_weight, device_output, warmup, repeat); diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index d290824b..250306b0 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -174,6 +174,9 @@ class igemm_fwd_gtc_t { int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); int group = arg->get_int("group_count"); + int splits = split_batch_size(arg, tunable); + n = n/splits; // split batch size here + int gemm_m_per_block = tunable->gemm_m_per_block; int gemm_n_per_block = tunable->gemm_n_per_block; int nxe = tunable->nxe; @@ -201,6 +204,54 @@ class igemm_fwd_gtc_t { return grid_size; } + // this is to support big tensor > 4G. 
need to decide how many splits needed + // return the number of splits + int split_batch_size(const args_t *arg, const igemm_gtc_tunable_t *tunable) + { + int hi = arg->get_int("in_h"); + int wi = arg->get_int("in_w"); + int n = arg->get_int("batchsize"); + int k = arg->get_int("out_channels"); + int c = arg->get_int("in_channels"); + + int stride_h = arg->get_int("conv_stride_h"); + int stride_w = arg->get_int("conv_stride_w"); + int dilation_h = arg->get_int("dilation_h"); + int dilation_w = arg->get_int("dilation_w"); + int pad_h = arg->get_int("pad_h"); + int pad_w = arg->get_int("pad_w"); + int y = arg->get_int("fil_h"); + int x = arg->get_int("fil_w"); + int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + + int data_byte = utility_string_to_data_byte(tunable->precision); + size_t image_size_input = static_cast(c) * hi * wi * data_byte; + size_t image_size_output = static_cast(k) * ho * wo * data_byte; + size_t size_4g = 0xffffffffUL; + if(image_size_input >= size_4g || image_size_output >= size_4g) + return 0; + + size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output; + size_t splited_n = size_4g / image_size; + + // round up splits, we must match + // 1. splited_n * image_size < size_4g + // 2. n % splited_n == 0 + // if(splited_n >= n) + // return 1; + assert(splited_n != 0); + while(splited_n >= 1){ + // printf("n:%d, splited_n:%d\n", n, splited_n); + if(n % splited_n == 0) + break; + splited_n--; + } + + assert(splited_n * image_size < size_4g && n % splited_n == 0); + return n / splited_n; + } + bool tunable_is_valid(const args_t *arg, const igemm_gtc_tunable_t *tunable) { @@ -224,6 +275,13 @@ class igemm_fwd_gtc_t { assert(c % group == 0 && k % group == 0); + int splits = split_batch_size(arg, tunable); + if(splits == 0){ + printf("image size (c*h*w) is bigger than 4g, which is not supported now\n"); + return false; + } + n = n/splits; // split batch size here + int gemm_m_per_block = tunable->gemm_m_per_block; int gemm_n_per_block = tunable->gemm_n_per_block; int gemm_k_per_block = tunable->gemm_k_per_block; @@ -375,6 +433,9 @@ class igemm_fwd_gtc_t { assert(c % group == 0 && k % group == 0); + int splits = split_batch_size(arg, tunable); + n = n/splits; // split batch size here + int gemm_m_per_block = tunable->gemm_m_per_block; int gemm_n_per_block = tunable->gemm_n_per_block; int gemm_k_per_block = tunable->gemm_k_per_block; @@ -494,7 +555,7 @@ class igemm_fwd_gtc_t { hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); auto launch_fwd = [&]() -> float { - // printf("launch fwd block:%d, grid:%d\n", block_size, grid_size); + // printf("launch fwd block:%d, grid:%dx%d\n", block_size, grid_size, splits); // dump_fwd_karg(&karg); void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, static_cast(&karg_buffer[0]), HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, @@ -508,7 +569,7 @@ class igemm_fwd_gtc_t { hipEventCreate(&stop); // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1, + HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1, block_size, 1, 1, 0, 0, NULL, (void **)&config, start, stop)); @@ -520,7 +581,7 @@ class igemm_fwd_gtc_t { gpu_timer_t timer(NULL); timer.start(); - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1, + HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1, 
block_size, 1, 1, 0, 0, NULL, (void **)&config)); diff --git a/igemm/algo/igemm_fwd_gtc.py b/igemm/algo/igemm_fwd_gtc.py index 5461f3bd..a020f599 100755 --- a/igemm/algo/igemm_fwd_gtc.py +++ b/igemm/algo/igemm_fwd_gtc.py @@ -674,7 +674,8 @@ def __init__(self, mc, outer): sseq = gpr_sequencer_t() self.outer = outer self.s_ka = sym_t('s_ka' , sseq(2)) - self.s_bx = sym_t('s_bx' , sseq(2)) + self.s_bx = sym_t('s_bx' , sseq(1)) + self.s_by = sym_t('s_by' , sseq(1)) self.s_p_in = sym_t('s_p_in' , sseq(4)) self.s_p_wei = sym_t('s_p_wei' , sseq(4)) self.s_p_out = sym_t('s_p_out' , sseq(4)) @@ -1230,6 +1231,7 @@ def get_kernel_code(self): kernel_code = amdgpu_kernel_code_t({ 'enable_sgpr_kernarg_segment_ptr' : 1, 'enable_sgpr_workgroup_id_x' : 1, + 'enable_sgpr_workgroup_id_y' : 1, 'enable_vgpr_workitem_id' : 0, 'workgroup_group_segment_byte_size' : self.tunable.lds_total, 'kernarg_segment_byte_size' : self.karg.get_count(), @@ -1521,6 +1523,22 @@ def emit_kernel_prologue(self): self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_n()}], {igemm_log2(nb_n0)}") self._emit(f"s_mul_i32 s[{s.s_out_stride_n0()}], s[{s.s_out_stride_n()}], s[{s.s_tmp()}]") + # calculate batch split and accumulate the base pointer for input/output + self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_n()}], s[{s.s_in_stride_n()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_n()}], s[{s.s_out_stride_n()}]") + self._emit(f"s_lshl_b32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], {igemm_log2(data_byte)}") + self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(1)}], {igemm_log2(data_byte)}") + + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(4)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(4)}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(5)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(5)}]") + self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") + # early init s_knum in case shifted if self.tunable.nxe != 0: self._emit(f"s_mul_i32 s[{s.s_knum()}], s[{s.s_wei_stride_c()}], s[{s.s_c()}]") From 5bb8587b15c7df613b20c54a0f2234fb703fc9f3 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 29 Jan 2021 17:14:03 +0800 Subject: [PATCH 16/40] split 4G support --- driver/gpu_naive_conv/naive_conv.cpp | 474 ++++++++++++++------------- igemm/algo/igemm_fwd_gtc_nhwc.py | 20 +- 2 files changed, 258 insertions(+), 236 deletions(-) diff --git a/driver/gpu_naive_conv/naive_conv.cpp b/driver/gpu_naive_conv/naive_conv.cpp index 5e2496fc..f8bede9d 100644 --- a/driver/gpu_naive_conv/naive_conv.cpp +++ b/driver/gpu_naive_conv/naive_conv.cpp @@ -103,19 +103,19 @@ extern "C" __global__ void naive_conv_fwd_nchw_fp32( valid_w &= 0; if (valid_w & valid_h) { - int i_idx = static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int w_idx = static_cast(ic) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t i_idx = static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t f_idx = static_cast(ic) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast(p_in[i_idx]) * static_cast(p_wei[w_idx]); } } } } - int o_idx = static_cast(iho) * wo + static_cast(iwo); + size_t o_idx = static_cast(iho) * wo + static_cast(iwo); p_out[o_idx] = 
static_cast(value); } } @@ -171,10 +171,10 @@ extern "C" __global__ void naive_conv_bwd_nchw_fp32( valid_w &= 0; if (valid_h & valid_w) { - int o_idx = static_cast(ik) * ho * wo + - static_cast(cur_ho) * wo + - static_cast(cur_wo); - int f_idx = + size_t o_idx = static_cast(ik) * ho * wo + + static_cast(cur_ho) * wo + + static_cast(cur_wo); + size_t f_idx = static_cast(ik) * c_per_group * fy * fx + static_cast(iy) * fx + static_cast(ix); @@ -184,7 +184,7 @@ extern "C" __global__ void naive_conv_bwd_nchw_fp32( } } } - int i_idx = static_cast(ihi) * wi + static_cast(iwi); + size_t i_idx = static_cast(ihi) * wi + static_cast(iwi); p_in[i_idx] = static_cast(value); } } @@ -234,21 +234,21 @@ extern "C" __global__ void naive_conv_wrw_nchw_fp32( valid_w &= 0; if (valid_h & valid_w) { - int i_idx = static_cast(in) * c * hi * wi + - static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int o_idx = static_cast(in) * k * ho * wo + - static_cast(iho) * wo + - static_cast(iwo); + size_t i_idx = static_cast(in) * c * hi * wi + + static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t o_idx = static_cast(in) * k * ho * wo + + static_cast(iho) * wo + + static_cast(iwo); value += static_cast(p_in[i_idx]) * static_cast(p_out[o_idx]); } } } } - int f_idx = static_cast(ic) * fy * fx + - static_cast(iy) * fx + static_cast(ix); + size_t f_idx = static_cast(ic) * fy * fx + + static_cast(iy) * fx + static_cast(ix); p_wei[f_idx] = static_cast(value); } } @@ -308,14 +308,16 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_fp32( valid_w &= 0; if (valid_d & valid_w & valid_h) { - int i_idx = static_cast(ic) * di * hi * wi + - static_cast(cur_d) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int w_idx = static_cast(ic) * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t i_idx = + static_cast(ic) * di * hi * wi + + static_cast(cur_d) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t f_idx = + static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast(p_in[i_idx]) * static_cast(p_wei[w_idx]); } @@ -323,8 +325,8 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_fp32( } } } - int o_idx = static_cast(ido) * ho * wo + - static_cast(iho) * wo + static_cast(iwo); + size_t o_idx = static_cast(ido) * ho * wo + + static_cast(iho) * wo + static_cast(iwo); p_out[o_idx] = static_cast(value); } } @@ -394,16 +396,16 @@ extern "C" __global__ void naive_conv_bwd_ncdhw_fp32( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int o_idx = + size_t o_idx = static_cast(ik) * do_ * ho * wo + static_cast(cur_do) * ho * wo + static_cast(cur_ho) * wo + static_cast(cur_wo); - int f_idx = static_cast(ik) * c_per_group * - fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t f_idx = static_cast(ik) * + c_per_group * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast(p_out[o_idx]) * static_cast(p_wei[f_idx]); } @@ -411,8 +413,8 @@ extern "C" __global__ void naive_conv_bwd_ncdhw_fp32( } } } - int i_idx = static_cast(idi) * hi * wi + - static_cast(ihi) * wi + static_cast(iwi); + size_t i_idx = static_cast(idi) * hi * wi + + static_cast(ihi) * wi + static_cast(iwi); p_in[i_idx] = static_cast(value); } } @@ -470,13 +472,13 @@ extern "C" __global__ void naive_conv_wrw_ncdhw_fp32( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int i_idx 
= + size_t i_idx = static_cast(in) * c * di * hi * wi + static_cast(ic) * di * hi * wi + static_cast(cur_d) * hi * wi + static_cast(cur_h) * wi + static_cast(cur_w); - int o_idx = + size_t o_idx = static_cast(in) * k * do_ * ho * wo + static_cast(ido) * ho * wo + static_cast(iho) * wo + @@ -488,9 +490,9 @@ extern "C" __global__ void naive_conv_wrw_ncdhw_fp32( } } } - int f_idx = static_cast(ic) * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + static_cast(ix); + size_t f_idx = static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + static_cast(ix); p_wei[f_idx] = static_cast(value); } } @@ -540,12 +542,12 @@ extern "C" __global__ void naive_conv_fwd_nchw_fp16( valid_w &= 0; if (valid_w & valid_h) { - int i_idx = static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int w_idx = static_cast(ic) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t i_idx = static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t f_idx = static_cast(ic) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast(__half2float(p_in[i_idx])) * static_cast(__half2float(p_wei[w_idx])); @@ -553,7 +555,7 @@ extern "C" __global__ void naive_conv_fwd_nchw_fp16( } } } - int o_idx = static_cast(iho) * wo + static_cast(iwo); + size_t o_idx = static_cast(iho) * wo + static_cast(iwo); p_out[o_idx] = __float2half(static_cast(value)); } } @@ -609,10 +611,10 @@ extern "C" __global__ void naive_conv_bwd_nchw_fp16( valid_w &= 0; if (valid_h & valid_w) { - int o_idx = static_cast(ik) * ho * wo + - static_cast(cur_ho) * wo + - static_cast(cur_wo); - int f_idx = + size_t o_idx = static_cast(ik) * ho * wo + + static_cast(cur_ho) * wo + + static_cast(cur_wo); + size_t f_idx = static_cast(ik) * c_per_group * fy * fx + static_cast(iy) * fx + static_cast(ix); @@ -623,7 +625,7 @@ extern "C" __global__ void naive_conv_bwd_nchw_fp16( } } } - int i_idx = static_cast(ihi) * wi + static_cast(iwi); + size_t i_idx = static_cast(ihi) * wi + static_cast(iwi); p_in[i_idx] = __float2half(static_cast(value)); } } @@ -673,13 +675,13 @@ extern "C" __global__ void naive_conv_wrw_nchw_fp16( valid_w &= 0; if (valid_h & valid_w) { - int i_idx = static_cast(in) * c * hi * wi + - static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int o_idx = static_cast(in) * k * ho * wo + - static_cast(iho) * wo + - static_cast(iwo); + size_t i_idx = static_cast(in) * c * hi * wi + + static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t o_idx = static_cast(in) * k * ho * wo + + static_cast(iho) * wo + + static_cast(iwo); value += static_cast(__half2float(p_in[i_idx])) * static_cast(__half2float(p_out[o_idx])); @@ -687,8 +689,8 @@ extern "C" __global__ void naive_conv_wrw_nchw_fp16( } } } - int f_idx = static_cast(ic) * fy * fx + - static_cast(iy) * fx + static_cast(ix); + size_t f_idx = static_cast(ic) * fy * fx + + static_cast(iy) * fx + static_cast(ix); p_wei[f_idx] = __float2half(static_cast(value)); } } @@ -748,14 +750,16 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_fp16( valid_w &= 0; if (valid_d & valid_w & valid_h) { - int i_idx = static_cast(ic) * di * hi * wi + - static_cast(cur_d) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int w_idx = static_cast(ic) * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t i_idx = + static_cast(ic) * di * hi * wi + + static_cast(cur_d) * hi * wi + + 
static_cast(cur_h) * wi + + static_cast(cur_w); + size_t f_idx = + static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast(__half2float(p_in[i_idx])) * static_cast(__half2float(p_wei[w_idx])); @@ -764,8 +768,8 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_fp16( } } } - int o_idx = static_cast(ido) * ho * wo + - static_cast(iho) * wo + static_cast(iwo); + size_t o_idx = static_cast(ido) * ho * wo + + static_cast(iho) * wo + static_cast(iwo); p_out[o_idx] = __float2half(static_cast(value)); } } @@ -835,16 +839,16 @@ extern "C" __global__ void naive_conv_bwd_ncdhw_fp16( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int o_idx = + size_t o_idx = static_cast(ik) * do_ * ho * wo + static_cast(cur_do) * ho * wo + static_cast(cur_ho) * wo + static_cast(cur_wo); - int f_idx = static_cast(ik) * c_per_group * - fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t f_idx = static_cast(ik) * + c_per_group * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast( __half2float(p_out[o_idx])) * @@ -854,8 +858,8 @@ extern "C" __global__ void naive_conv_bwd_ncdhw_fp16( } } } - int i_idx = static_cast(idi) * hi * wi + - static_cast(ihi) * wi + static_cast(iwi); + size_t i_idx = static_cast(idi) * hi * wi + + static_cast(ihi) * wi + static_cast(iwi); p_in[i_idx] = __float2half(static_cast(value)); } } @@ -913,13 +917,13 @@ extern "C" __global__ void naive_conv_wrw_ncdhw_fp16( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int i_idx = + size_t i_idx = static_cast(in) * c * di * hi * wi + static_cast(ic) * di * hi * wi + static_cast(cur_d) * hi * wi + static_cast(cur_h) * wi + static_cast(cur_w); - int o_idx = + size_t o_idx = static_cast(in) * k * do_ * ho * wo + static_cast(ido) * ho * wo + static_cast(iho) * wo + @@ -932,9 +936,9 @@ extern "C" __global__ void naive_conv_wrw_ncdhw_fp16( } } } - int f_idx = static_cast(ic) * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + static_cast(ix); + size_t f_idx = static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + static_cast(ix); p_wei[f_idx] = __float2half(static_cast(value)); } } @@ -984,12 +988,12 @@ extern "C" __global__ void naive_conv_fwd_nchw_bf16( valid_w &= 0; if (valid_w & valid_h) { - int i_idx = static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int w_idx = static_cast(ic) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t i_idx = static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t f_idx = static_cast(ic) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( @@ -998,7 +1002,7 @@ extern "C" __global__ void naive_conv_fwd_nchw_bf16( } } } - int o_idx = static_cast(iho) * wo + static_cast(iwo); + size_t o_idx = static_cast(iho) * wo + static_cast(iwo); p_out[o_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -1054,10 +1058,10 @@ extern "C" __global__ void naive_conv_bwd_nchw_bf16( valid_w &= 0; if (valid_h & valid_w) { - int o_idx = static_cast(ik) * ho * wo + - static_cast(cur_ho) * wo + - static_cast(cur_wo); - int f_idx = + size_t o_idx = static_cast(ik) * ho * wo + + static_cast(cur_ho) * wo + + static_cast(cur_wo); + size_t f_idx = static_cast(ik) * c_per_group * fy * fx + static_cast(iy) * fx + static_cast(ix); @@ -1069,7 +1073,7 @@ extern "C" __global__ void 
naive_conv_bwd_nchw_bf16( } } } - int i_idx = static_cast(ihi) * wi + static_cast(iwi); + size_t i_idx = static_cast(ihi) * wi + static_cast(iwi); p_in[i_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -1119,13 +1123,13 @@ extern "C" __global__ void naive_conv_wrw_nchw_bf16( valid_w &= 0; if (valid_h & valid_w) { - int i_idx = static_cast(in) * c * hi * wi + - static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int o_idx = static_cast(in) * k * ho * wo + - static_cast(iho) * wo + - static_cast(iwo); + size_t i_idx = static_cast(in) * c * hi * wi + + static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t o_idx = static_cast(in) * k * ho * wo + + static_cast(iho) * wo + + static_cast(iwo); value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( @@ -1134,8 +1138,8 @@ extern "C" __global__ void naive_conv_wrw_nchw_bf16( } } } - int f_idx = static_cast(ic) * fy * fx + - static_cast(iy) * fx + static_cast(ix); + size_t f_idx = static_cast(ic) * fy * fx + + static_cast(iy) * fx + static_cast(ix); p_wei[f_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -1195,14 +1199,16 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_bf16( valid_w &= 0; if (valid_d & valid_w & valid_h) { - int i_idx = static_cast(ic) * di * hi * wi + - static_cast(cur_d) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - int w_idx = static_cast(ic) * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t i_idx = + static_cast(ic) * di * hi * wi + + static_cast(cur_d) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + size_t f_idx = + static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( @@ -1212,8 +1218,8 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_bf16( } } } - int o_idx = static_cast(ido) * ho * wo + - static_cast(iho) * wo + static_cast(iwo); + size_t o_idx = static_cast(ido) * ho * wo + + static_cast(iho) * wo + static_cast(iwo); p_out[o_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -1283,16 +1289,16 @@ extern "C" __global__ void naive_conv_bwd_ncdhw_bf16( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int o_idx = + size_t o_idx = static_cast(ik) * do_ * ho * wo + static_cast(cur_do) * ho * wo + static_cast(cur_ho) * wo + static_cast(cur_wo); - int f_idx = static_cast(ik) * c_per_group * - fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + - static_cast(ix); + size_t f_idx = static_cast(ik) * + c_per_group * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); value += static_cast( __bfloat16_to_float(p_out[o_idx])) * static_cast( @@ -1302,8 +1308,8 @@ extern "C" __global__ void naive_conv_bwd_ncdhw_bf16( } } } - int i_idx = static_cast(idi) * hi * wi + - static_cast(ihi) * wi + static_cast(iwi); + size_t i_idx = static_cast(idi) * hi * wi + + static_cast(ihi) * wi + static_cast(iwi); p_in[i_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -1361,13 +1367,13 @@ extern "C" __global__ void naive_conv_wrw_ncdhw_bf16( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int i_idx = + size_t i_idx = static_cast(in) * c * di * hi * wi + static_cast(ic) * di * hi * wi + static_cast(cur_d) * hi * wi + static_cast(cur_h) * wi + static_cast(cur_w); - int o_idx = + size_t o_idx = static_cast(in) * k * do_ * ho * wo + static_cast(ido) * ho * wo + static_cast(iho) * wo + @@ -1381,9 
+1387,9 @@ extern "C" __global__ void naive_conv_wrw_ncdhw_bf16( } } } - int f_idx = static_cast(ic) * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + static_cast(ix); + size_t f_idx = static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + static_cast(ix); p_wei[f_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -1435,10 +1441,10 @@ extern "C" __global__ void naive_conv_fwd_nhwc_fp32( valid_w &= 0; if (valid_w & valid_h) { - int i_idx = static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + - static_cast(ic); - int w_idx = + size_t i_idx = static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + size_t f_idx = static_cast(ik) * fy * fx * c_per_group + static_cast(iy) * fx * c_per_group + static_cast(ix) * c_per_group + @@ -1449,7 +1455,7 @@ extern "C" __global__ void naive_conv_fwd_nhwc_fp32( } } } - int o_idx = static_cast(iwo) * k + static_cast(ik); + size_t o_idx = static_cast(iwo) * k + static_cast(ik); p_out[o_idx] = static_cast(value); } } @@ -1505,10 +1511,10 @@ extern "C" __global__ void naive_conv_bwd_nhwc_fp32( for (int ik = 0; ik < k_per_group; ik++) { if (valid_h & valid_w) { - int o_idx = static_cast(cur_ho) * wo * k + - static_cast(cur_wo) * k + - static_cast(ik); - int f_idx = + size_t o_idx = static_cast(cur_ho) * wo * k + + static_cast(cur_wo) * k + + static_cast(ik); + size_t f_idx = static_cast(ik) * fy * fx * c_per_group + static_cast(iy) * fx * c_per_group + static_cast(ix) * c_per_group + @@ -1519,7 +1525,7 @@ extern "C" __global__ void naive_conv_bwd_nhwc_fp32( } } } - int i_idx = static_cast(iwi) * c + static_cast(ic); + size_t i_idx = static_cast(iwi) * c + static_cast(ic); p_in[i_idx] = static_cast(value); } } @@ -1568,22 +1574,22 @@ extern "C" __global__ void naive_conv_wrw_nhwc_fp32( valid_w &= 0; if (valid_h & valid_w) { - int i_idx = static_cast(in) * hi * wi * c + - static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + - static_cast(ic); - int o_idx = static_cast(in) * ho * wo * k + - static_cast(iho) * wo * k + - static_cast(iwo) * k; + size_t i_idx = static_cast(in) * hi * wi * c + + static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + size_t o_idx = static_cast(in) * ho * wo * k + + static_cast(iho) * wo * k + + static_cast(iwo) * k; value += static_cast(p_in[i_idx]) * static_cast(p_out[o_idx]); } } } } - int f_idx = static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); + size_t f_idx = static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); p_wei[f_idx] = static_cast(value); } } @@ -1640,12 +1646,12 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_fp32( valid_w &= 0; for (int ic = 0; ic < c_per_group; ic++) { if (valid_d & valid_w & valid_h) { - int i_idx = + size_t i_idx = static_cast(cur_d) * hi * wi * c + static_cast(cur_h) * wi * c + static_cast(cur_w) * c + static_cast(ic); - int w_idx = + size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + static_cast(iz) * fy * fx * @@ -1660,8 +1666,8 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_fp32( } } } - int o_idx = static_cast(iho) * wo * k + - static_cast(iwo) * k + static_cast(ik); + size_t o_idx = static_cast(iho) * wo * k + + static_cast(iwo) * k + static_cast(ik); p_out[o_idx] = static_cast(value); } } @@ -1727,12 +1733,12 @@ extern "C" __global__ void naive_conv_bwd_ndhwc_fp32( valid_w &= 0; for (int ik = 0; ik < k_per_group; ik++) { if (valid_d & valid_h & valid_w) { - int o_idx = + size_t o_idx = 
static_cast(cur_do) * ho * wo * k + static_cast(cur_ho) * wo * k + static_cast(cur_wo) * k + static_cast(ik); - int f_idx = + size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + static_cast(iz) * fy * fx * @@ -1747,8 +1753,8 @@ extern "C" __global__ void naive_conv_bwd_ndhwc_fp32( } } } - int i_idx = static_cast(ihi) * wi * c + - static_cast(iwi) * c + static_cast(ic); + size_t i_idx = static_cast(ihi) * wi * c + + static_cast(iwi) * c + static_cast(ic); p_in[i_idx] = static_cast(value); } } @@ -1805,13 +1811,13 @@ extern "C" __global__ void naive_conv_wrw_ndhwc_fp32( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int i_idx = + size_t i_idx = static_cast(in) * di * hi * wi * c + static_cast(cur_d) * hi * wi * c + static_cast(cur_h) * wi * c + static_cast(cur_w) * c + static_cast(ic); - int o_idx = + size_t o_idx = static_cast(in) * do_ * ho * wo * k + static_cast(ido) * ho * wo * k + static_cast(iho) * wo * k + @@ -1823,10 +1829,10 @@ extern "C" __global__ void naive_conv_wrw_ndhwc_fp32( } } } - int f_idx = static_cast(iz) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); + size_t f_idx = static_cast(iz) * fy * fx * c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); p_wei[f_idx] = static_cast(value); } } @@ -1877,10 +1883,10 @@ extern "C" __global__ void naive_conv_fwd_nhwc_fp16( valid_w &= 0; if (valid_w & valid_h) { - int i_idx = static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + - static_cast(ic); - int w_idx = + size_t i_idx = static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + size_t f_idx = static_cast(ik) * fy * fx * c_per_group + static_cast(iy) * fx * c_per_group + static_cast(ix) * c_per_group + @@ -1892,7 +1898,7 @@ extern "C" __global__ void naive_conv_fwd_nhwc_fp16( } } } - int o_idx = static_cast(iwo) * k + static_cast(ik); + size_t o_idx = static_cast(iwo) * k + static_cast(ik); p_out[o_idx] = __float2half(static_cast(value)); } } @@ -1948,10 +1954,10 @@ extern "C" __global__ void naive_conv_bwd_nhwc_fp16( for (int ik = 0; ik < k_per_group; ik++) { if (valid_h & valid_w) { - int o_idx = static_cast(cur_ho) * wo * k + - static_cast(cur_wo) * k + - static_cast(ik); - int f_idx = + size_t o_idx = static_cast(cur_ho) * wo * k + + static_cast(cur_wo) * k + + static_cast(ik); + size_t f_idx = static_cast(ik) * fy * fx * c_per_group + static_cast(iy) * fx * c_per_group + static_cast(ix) * c_per_group + @@ -1963,7 +1969,7 @@ extern "C" __global__ void naive_conv_bwd_nhwc_fp16( } } } - int i_idx = static_cast(iwi) * c + static_cast(ic); + size_t i_idx = static_cast(iwi) * c + static_cast(ic); p_in[i_idx] = __float2half(static_cast(value)); } } @@ -2012,13 +2018,13 @@ extern "C" __global__ void naive_conv_wrw_nhwc_fp16( valid_w &= 0; if (valid_h & valid_w) { - int i_idx = static_cast(in) * hi * wi * c + - static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + - static_cast(ic); - int o_idx = static_cast(in) * ho * wo * k + - static_cast(iho) * wo * k + - static_cast(iwo) * k; + size_t i_idx = static_cast(in) * hi * wi * c + + static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + size_t o_idx = static_cast(in) * ho * wo * k + + static_cast(iho) * wo * k + + static_cast(iwo) * k; value += static_cast(__half2float(p_in[i_idx])) * static_cast(__half2float(p_out[o_idx])); @@ -2026,9 +2032,9 @@ extern "C" __global__ void naive_conv_wrw_nhwc_fp16( } } } - int f_idx = static_cast(iy) * fx * 
c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); + size_t f_idx = static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); p_wei[f_idx] = __float2half(static_cast(value)); } } @@ -2085,12 +2091,12 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_fp16( valid_w &= 0; for (int ic = 0; ic < c_per_group; ic++) { if (valid_d & valid_w & valid_h) { - int i_idx = + size_t i_idx = static_cast(cur_d) * hi * wi * c + static_cast(cur_h) * wi * c + static_cast(cur_w) * c + static_cast(ic); - int w_idx = + size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + static_cast(iz) * fy * fx * @@ -2106,8 +2112,8 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_fp16( } } } - int o_idx = static_cast(iho) * wo * k + - static_cast(iwo) * k + static_cast(ik); + size_t o_idx = static_cast(iho) * wo * k + + static_cast(iwo) * k + static_cast(ik); p_out[o_idx] = __float2half(static_cast(value)); } } @@ -2173,12 +2179,12 @@ extern "C" __global__ void naive_conv_bwd_ndhwc_fp16( valid_w &= 0; for (int ik = 0; ik < k_per_group; ik++) { if (valid_d & valid_h & valid_w) { - int o_idx = + size_t o_idx = static_cast(cur_do) * ho * wo * k + static_cast(cur_ho) * wo * k + static_cast(cur_wo) * k + static_cast(ik); - int f_idx = + size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + static_cast(iz) * fy * fx * @@ -2195,8 +2201,8 @@ extern "C" __global__ void naive_conv_bwd_ndhwc_fp16( } } } - int i_idx = static_cast(ihi) * wi * c + - static_cast(iwi) * c + static_cast(ic); + size_t i_idx = static_cast(ihi) * wi * c + + static_cast(iwi) * c + static_cast(ic); p_in[i_idx] = __float2half(static_cast(value)); } } @@ -2253,13 +2259,13 @@ extern "C" __global__ void naive_conv_wrw_ndhwc_fp16( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int i_idx = + size_t i_idx = static_cast(in) * di * hi * wi * c + static_cast(cur_d) * hi * wi * c + static_cast(cur_h) * wi * c + static_cast(cur_w) * c + static_cast(ic); - int o_idx = + size_t o_idx = static_cast(in) * do_ * ho * wo * k + static_cast(ido) * ho * wo * k + static_cast(iho) * wo * k + @@ -2272,10 +2278,10 @@ extern "C" __global__ void naive_conv_wrw_ndhwc_fp16( } } } - int f_idx = static_cast(iz) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); + size_t f_idx = static_cast(iz) * fy * fx * c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); p_wei[f_idx] = __float2half(static_cast(value)); } } @@ -2326,10 +2332,10 @@ extern "C" __global__ void naive_conv_fwd_nhwc_bf16( valid_w &= 0; if (valid_w & valid_h) { - int i_idx = static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + - static_cast(ic); - int w_idx = + size_t i_idx = static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + size_t f_idx = static_cast(ik) * fy * fx * c_per_group + static_cast(iy) * fx * c_per_group + static_cast(ix) * c_per_group + @@ -2342,7 +2348,7 @@ extern "C" __global__ void naive_conv_fwd_nhwc_bf16( } } } - int o_idx = static_cast(iwo) * k + static_cast(ik); + size_t o_idx = static_cast(iwo) * k + static_cast(ik); p_out[o_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -2398,10 +2404,10 @@ extern "C" __global__ void naive_conv_bwd_nhwc_bf16( for (int ik = 0; ik < k_per_group; ik++) { if (valid_h & valid_w) { - int o_idx = static_cast(cur_ho) * wo * k + - static_cast(cur_wo) * k + - static_cast(ik); - int f_idx = + size_t o_idx = static_cast(cur_ho) * wo * k + + static_cast(cur_wo) * k 
+ + static_cast(ik); + size_t f_idx = static_cast(ik) * fy * fx * c_per_group + static_cast(iy) * fx * c_per_group + static_cast(ix) * c_per_group + @@ -2414,7 +2420,7 @@ extern "C" __global__ void naive_conv_bwd_nhwc_bf16( } } } - int i_idx = static_cast(iwi) * c + static_cast(ic); + size_t i_idx = static_cast(iwi) * c + static_cast(ic); p_in[i_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -2463,13 +2469,13 @@ extern "C" __global__ void naive_conv_wrw_nhwc_bf16( valid_w &= 0; if (valid_h & valid_w) { - int i_idx = static_cast(in) * hi * wi * c + - static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + - static_cast(ic); - int o_idx = static_cast(in) * ho * wo * k + - static_cast(iho) * wo * k + - static_cast(iwo) * k; + size_t i_idx = static_cast(in) * hi * wi * c + + static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + size_t o_idx = static_cast(in) * ho * wo * k + + static_cast(iho) * wo * k + + static_cast(iwo) * k; value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( @@ -2478,9 +2484,9 @@ extern "C" __global__ void naive_conv_wrw_nhwc_bf16( } } } - int f_idx = static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); + size_t f_idx = static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); p_wei[f_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -2537,12 +2543,12 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_bf16( valid_w &= 0; for (int ic = 0; ic < c_per_group; ic++) { if (valid_d & valid_w & valid_h) { - int i_idx = + size_t i_idx = static_cast(cur_d) * hi * wi * c + static_cast(cur_h) * wi * c + static_cast(cur_w) * c + static_cast(ic); - int w_idx = + size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + static_cast(iz) * fy * fx * @@ -2559,8 +2565,8 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_bf16( } } } - int o_idx = static_cast(iho) * wo * k + - static_cast(iwo) * k + static_cast(ik); + size_t o_idx = static_cast(iho) * wo * k + + static_cast(iwo) * k + static_cast(ik); p_out[o_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -2626,12 +2632,12 @@ extern "C" __global__ void naive_conv_bwd_ndhwc_bf16( valid_w &= 0; for (int ik = 0; ik < k_per_group; ik++) { if (valid_d & valid_h & valid_w) { - int o_idx = + size_t o_idx = static_cast(cur_do) * ho * wo * k + static_cast(cur_ho) * wo * k + static_cast(cur_wo) * k + static_cast(ik); - int f_idx = + size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + static_cast(iz) * fy * fx * @@ -2648,8 +2654,8 @@ extern "C" __global__ void naive_conv_bwd_ndhwc_bf16( } } } - int i_idx = static_cast(ihi) * wi * c + - static_cast(iwi) * c + static_cast(ic); + size_t i_idx = static_cast(ihi) * wi * c + + static_cast(iwi) * c + static_cast(ic); p_in[i_idx] = __float_to_bfloat16(static_cast(value)); } } @@ -2706,13 +2712,13 @@ extern "C" __global__ void naive_conv_wrw_ndhwc_bf16( valid_w &= 0; if (valid_d & valid_h & valid_w) { - int i_idx = + size_t i_idx = static_cast(in) * di * hi * wi * c + static_cast(cur_d) * hi * wi * c + static_cast(cur_h) * wi * c + static_cast(cur_w) * c + static_cast(ic); - int o_idx = + size_t o_idx = static_cast(in) * do_ * ho * wo * k + static_cast(ido) * ho * wo * k + static_cast(iho) * wo * k + @@ -2726,10 +2732,10 @@ extern "C" __global__ void naive_conv_wrw_ndhwc_bf16( } } } - int f_idx = static_cast(iz) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); + size_t f_idx = static_cast(iz) * fy * 
fx * c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); p_wei[f_idx] = __float_to_bfloat16(static_cast(value)); } } \ No newline at end of file diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index 8c0b2184..ce9941ab 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -428,7 +428,8 @@ def __init__(self, mc, outer): sseq = gpr_sequencer_t() self.outer = outer self.s_ka = sym_t('s_ka' , sseq(2)) - self.s_bx = sym_t('s_bx' , sseq(2)) + self.s_bx = sym_t('s_bx' , sseq(1)) + self.s_by = sym_t('s_by' , sseq(1)) self.s_p_in = sym_t('s_p_in' , sseq(4)) self.s_p_wei = sym_t('s_p_wei' , sseq(4)) self.s_p_out = sym_t('s_p_out' , sseq(4)) @@ -812,6 +813,7 @@ def get_kernel_code(self): kernel_code = amdgpu_kernel_code_t({ 'enable_sgpr_kernarg_segment_ptr' : 1, 'enable_sgpr_workgroup_id_x' : 1, + 'enable_sgpr_workgroup_id_y' : 1, 'enable_vgpr_workitem_id' : 0, 'workgroup_group_segment_byte_size' : self.tunable.lds_total, 'kernarg_segment_byte_size' : self.karg.get_count(), @@ -1002,7 +1004,21 @@ def emit_kernel_prologue(self): self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_wo() if self.tunable.nxe != 0 else s.s_wi()}], s[{s.s_out_stride_wo()}]") self._emit(f"s_mul_i32 s[{s.s_out_stride_n()}], s[{s.s_ho() if self.tunable.nxe != 0 else s.s_hi()}], s[{s.s_tmp(1)}]") - # TODO: accumulate splited batch here + # calculate batch split and accumulate the base pointer for input/output + self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_n()}], s[{s.s_in_stride_n()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_n()}], s[{s.s_out_stride_n()}]") + self._emit(f"s_lshl_b32 s[{s.s_tmp(4)}], s[{s.s_tmp(0)}], {igemm_log2(data_byte)}") + self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_tmp(1)}], {igemm_log2(data_byte)}") + + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(4)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(4)}]") + self._emit(f"s_add_u32 s[{s.s_p_in()}], s[{s.s_p_in()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_by()}], s[{s.s_tmp(5)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_by()}], s[{s.s_tmp(5)}]") + self._emit(f"s_add_u32 s[{s.s_p_out()}], s[{s.s_p_out()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") # early init s_knum in case shifted self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_wei_stride_k()}]") From fc4e716919a1a79afa0e8be54dc67db42735e9fe Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 29 Jan 2021 17:15:21 +0800 Subject: [PATCH 17/40] fix type --- driver/gpu_naive_conv/naive_conv.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/driver/gpu_naive_conv/naive_conv.cpp b/driver/gpu_naive_conv/naive_conv.cpp index f8bede9d..389f8290 100644 --- a/driver/gpu_naive_conv/naive_conv.cpp +++ b/driver/gpu_naive_conv/naive_conv.cpp @@ -110,7 +110,7 @@ extern "C" __global__ void naive_conv_fwd_nchw_fp32( static_cast(iy) * fx + static_cast(ix); value += static_cast(p_in[i_idx]) * - static_cast(p_wei[w_idx]); + static_cast(p_wei[f_idx]); } } } @@ -319,7 +319,7 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_fp32( static_cast(iy) * fx + static_cast(ix); value += static_cast(p_in[i_idx]) * - static_cast(p_wei[w_idx]); + static_cast(p_wei[f_idx]); } } } @@ -550,7 +550,7 @@ extern "C" __global__ void 
naive_conv_fwd_nchw_fp16( static_cast(ix); value += static_cast(__half2float(p_in[i_idx])) * - static_cast(__half2float(p_wei[w_idx])); + static_cast(__half2float(p_wei[f_idx])); } } } @@ -762,7 +762,7 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_fp16( static_cast(ix); value += static_cast(__half2float(p_in[i_idx])) * - static_cast(__half2float(p_wei[w_idx])); + static_cast(__half2float(p_wei[f_idx])); } } } @@ -997,7 +997,7 @@ extern "C" __global__ void naive_conv_fwd_nchw_bf16( value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( - __bfloat16_to_float(p_wei[w_idx])); + __bfloat16_to_float(p_wei[f_idx])); } } } @@ -1212,7 +1212,7 @@ extern "C" __global__ void naive_conv_fwd_ncdhw_bf16( value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( - __bfloat16_to_float(p_wei[w_idx])); + __bfloat16_to_float(p_wei[f_idx])); } } } @@ -1450,7 +1450,7 @@ extern "C" __global__ void naive_conv_fwd_nhwc_fp32( static_cast(ix) * c_per_group + static_cast(ic); value += static_cast(p_in[i_idx]) * - static_cast(p_wei[w_idx]); + static_cast(p_wei[f_idx]); } } } @@ -1660,7 +1660,7 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_fp32( static_cast(ix) * c_per_group + static_cast(ic); value += static_cast(p_in[i_idx]) * - static_cast(p_wei[w_idx]); + static_cast(p_wei[f_idx]); } } } @@ -1893,7 +1893,7 @@ extern "C" __global__ void naive_conv_fwd_nhwc_fp16( static_cast(ic); value += static_cast(__half2float(p_in[i_idx])) * - static_cast(__half2float(p_wei[w_idx])); + static_cast(__half2float(p_wei[f_idx])); } } } @@ -2106,7 +2106,7 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_fp16( static_cast(ic); value += static_cast(__half2float(p_in[i_idx])) * - static_cast(__half2float(p_wei[w_idx])); + static_cast(__half2float(p_wei[f_idx])); } } } @@ -2343,7 +2343,7 @@ extern "C" __global__ void naive_conv_fwd_nhwc_bf16( value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( - __bfloat16_to_float(p_wei[w_idx])); + __bfloat16_to_float(p_wei[f_idx])); } } } @@ -2559,7 +2559,7 @@ extern "C" __global__ void naive_conv_fwd_ndhwc_bf16( value += static_cast( __bfloat16_to_float(p_in[i_idx])) * static_cast( - __bfloat16_to_float(p_wei[w_idx])); + __bfloat16_to_float(p_wei[f_idx])); } } } From 425ddbc1bba0d0f2f24f4140c98629265ada6b3a Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 29 Jan 2021 17:22:58 +0800 Subject: [PATCH 18/40] fix size --- driver/conv_driver.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index b9aa9815..fb489eb4 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -517,7 +517,7 @@ int main(int argc, char **argv) { fflush(stdout); if (need_verify) - HIP_CALL(hipMemset(device_output, 0, n * k * ho * wo * sizeof(float))); + HIP_CALL(hipMemset(device_output, 0, static_cast(n) * k * ho * wo * sizeof(float))); result_t result = conv_fwd_driver.run(&conv_args, tunable, module, device_input, From ded5a914affd7c3b4be12b58ea93ad44e000de69 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 6 Feb 2021 00:40:12 +0800 Subject: [PATCH 19/40] remove a useless check --- driver/igemm_fwd_gtc_driver.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 250306b0..652e7d53 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -376,11 +376,11 @@ class igemm_fwd_gtc_t { } // input vector load limitation, n1b - if(tunable->tensor_a_thread_lengths[3] > 1 && ( - 
!unit_conv || - unit_conv && (hi * wi) % tunable->tensor_a_thread_lengths[3] != 0)) { - return false; - } + //if(tunable->tensor_a_thread_lengths[3] > 1 && ( + // !unit_conv || + // unit_conv && (hi * wi) % tunable->tensor_a_thread_lengths[3] != 0)) { + // return false; + //} // // weight vector load limitation, c1e // if(tunable->tensor_a_thread_lengths[1] > 1 && From 20d501bccaaa691ef659655b4d05e30f13c31f24 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 7 Feb 2021 17:14:24 +0800 Subject: [PATCH 20/40] pad gemm_n --- driver/igemm_fwd_gtc_driver.h | 15 +++++--- igemm/algo/global_memory.py | 14 +++++-- igemm/algo/igemm_fwd_gtc_nhwc.py | 64 +++++++++++++++++++++++++------- 3 files changed, 70 insertions(+), 23 deletions(-) diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 652e7d53..6d9058e8 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -346,16 +346,19 @@ class igemm_fwd_gtc_t { return false; } }else if(tunable->tensor_layout == "nhwc"){ - int gemm_m = n * b; + //int gemm_m = n * b; // int gemm_n = ((k/group + gemm_n_per_block -1)/gemm_n_per_block) * gemm_n_per_block; - int gemm_n = k / group; - int gemm_k = (c / group) * y * x; + //int gemm_n = k / group; + //int gemm_k = (c / group) * y * x; // support pad to modulo, hence only check when nxe is 0 - if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) - { + //if((gemm_n % gemm_n_per_block != 0) || (gemm_m % gemm_m_per_block != 0)) + //{ + // return false; + //} + + if((c / group) % gemm_k_per_block != 0) return false; - } // if(gemm_m_per_block % tunable->nxb != 0){ // //printf("tunable_is_valid false: gemm_n_per_block%tunable->nxb!=0, gemm_n_per_block is %d, tunable->nxb is %d\n", gemm_n_per_block, tunable->nxb); diff --git a/igemm/algo/global_memory.py b/igemm/algo/global_memory.py index 68936eb0..92df4181 100755 --- a/igemm/algo/global_memory.py +++ b/igemm/algo/global_memory.py @@ -107,6 +107,7 @@ def __init__(self): self.precision = 'fp32' # 'fp32', 'fp16', ... 
self.src_order = 0 # 0-d0xd1, 1-d1xd0 self.dst_order = 0 # 0-d0xd1, 1-d1xd0 + self.use_flag = 0 self.bfe_flag = 0 class macro_igemm_2d_global_load_t(macro_base_t): @@ -226,6 +227,8 @@ def __init__(self, mc, ctrl, inline = False): self.declare_arg("s_stride_d0") self.declare_arg("s_stride_d1") self.declare_arg("s_offset") + if self.ctrl.use_flag: + self.declare_arg("v_flag") def name(self): ctrl = self.ctrl @@ -331,6 +334,8 @@ def expr(self): i_soffset = 0 for i_d0 in range(ctrl.length_d0): for i_d1 in range(n_d1): + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_dst)}]") if i_d0 == 0 and i_d1 == 0: self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os()}", f"{self.s_ptr()}", 0, 0)) elif i_d0 == 0 and i_d1 == 1: @@ -340,6 +345,8 @@ def expr(self): else: self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os()}", f"{self.s_ptr()}", f"{self.s_offset()}+{i_soffset}", 0)) i_soffset += 1 + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") i_dst = i_dst + 1 elif ctrl.src_order == 1 and ctrl.dst_order == 0: @@ -381,7 +388,8 @@ def __init__(self, mc, ctrl, inline = False): self.declare_arg("s_ptr") self.declare_arg("s_os") self.declare_arg("v_os") - self.declare_arg("v_flag") + if self.ctrl.use_flag: + self.declare_arg("v_flag") if self.ctrl.bfe_flag: self.declare_arg("v_tmp") @@ -415,14 +423,14 @@ def expr(self): i_cnt = 0 for i_d0 in range(ctrl.length_d0): for i_d1 in range(n_d1): - if self.v_flag != None: + if ctrl.use_flag and self.v_flag != None: if ctrl.bfe_flag: self._emit(f"v_bfe_u32 v[{self.v_tmp()}], v[{self.v_flag()}], {i_cnt}, 1") self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_tmp()}]") else: self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_cnt)}]") self._emit(buffer_load_dword(f"{self.v_dst()}+{i_cnt*ctrl.vector_d1}", f"{self.v_os(i_cnt)}", f"{self.s_ptr()}", f"{self.s_os()}", 0)) - if self.v_flag != None: + if ctrl.use_flag and self.v_flag != None: self._emit(f"s_mov_b64 exec, -1") i_cnt += 1 diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index ce9941ab..35010592 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -292,8 +292,8 @@ def expr(self): self._emit(m_set_flag_nhw(self.v_tmp(0), self.v_tmp(1), self.v_in_ihi_list(i), self.v_in_iwi_list(i), self.s_hi(), self.s_wi())) self._emit(f"v_lshl_or_b32 v[{self.v_in_flag()}], v[{self.v_tmp(3)}], {i}, v[{self.v_in_flag()}] ; reset flag") else: - self._emit(f"v_bfe_u32 v[{self.v_tmp(1)}], v[{self.v_in_flag_n()}], {i}, 1 ; extract flag_n") - self._emit(m_set_flag_nhw(self.v_in_flag(i), self.v_tmp(1), self.v_in_ihi_list(i), self.v_in_iwi_list(i), self.s_hi(), self.s_wi())) + self._emit(f"v_bfe_u32 v[{self.v_tmp(5)}], v[{self.v_in_flag_n()}], {i}, 1 ; extract flag_n") + self._emit(m_set_flag_nhw(self.v_in_flag(i), self.v_tmp(5), self.v_in_ihi_list(i), self.v_in_iwi_list(i), self.s_hi(), self.s_wi())) self._emit_front(f"{label_acc_yx_end}:") self._emit_empty_line() @@ -340,9 +340,9 @@ def __call__(self): s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.outer.get_symbol_global_load_s_stride_d0_d1() with self._deferred_context(): self._emit(f"; load weight") - # self._emit(f".v_clear_nc {v.v_gld_a()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") if self.outer.tunable.precache_soffset: - self._emit(m_wei_2d_global_load(v.v_gld_b(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), 
s_wei_stride_d1(), s.s_wei_offset())) + self._emit(f".v_clear_nc {v.v_gld_b()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") + self._emit(m_wei_2d_global_load(v.v_gld_b(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), v.v_wei_flag())) else: self._emit(m_wei_2d_global_load(v.v_gld_b(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_tmp())) return self._get_deferred() @@ -519,6 +519,7 @@ def __init__(self, mc, outer): ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = outer.get_cluster_lengths() nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + nk_per_thread = tb_k0 if tb_k0 != 1 else tb_k1 assert nb_per_thread <= 16, "we pack flag into single vgpr" is_vgpr_acc_c = outer.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS @@ -560,23 +561,26 @@ def __init__(self, mc, outer): self.v_gtc_ic = sym_t("v_gtc_ic" ,vseq(1)) self.v_in_inb = sym_t("v_in_inb" ,vseq(1)) self.v_in_in = sym_t("v_in_in" ,vseq(1)) - self.v_wei_ik = sym_t("v_wei_ik" ,vseq(1)) - self.v_co_sst = sym_t("v_co_sst" ,vseq(1)) + self.v_co_sst = sym_t("v_co_sst" ,self.v_in_in.value) self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) if outer.tunable.nxe != 0: - self.v_out_flag = sym_t("v_out_flag" ,vseq(1)) - self.v_out_inb = sym_t("v_out_inb" ,vseq(1)) + self.v_out_flag = sym_t("v_out_flag" ,self.v_wei_ik.value) + self.v_out_inb = sym_t("v_out_inb" ,self.v_in_inb.value) self.v_gemm_in = sym_t("v_gemm_in" ,vseq(1)) self.v_gemm_im = sym_t("v_gemm_im" ,vseq(1)) - - self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,vseq(1)) - self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,vseq(1)) + self.v_co_sub_m_index = sym_t("v_co_sub_m_index" ,self.v_gemm_im.value) + self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,self.v_gemm_in.value) self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) + if nk_per_thread <= 4 and IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG == 0: + self.v_wei_flag = sym_t("v_wei_flag" ,self.v_tmp.value) + else: + self.v_wei_flag = sym_t("v_wei_flag" ,vseq(nk_per_thread)) + total_vgpr = vseq() if outer.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # if xdlops agpr is larger than vgpr usage, must change vgpr count to agpr @@ -704,6 +708,9 @@ def get_macro_global_load(self): else: assert False + ctrl_in_gld.use_flag = 1 + ctrl_wei_gld.use_flag = 1 + if self.tunable.nxe != 0: if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: ctrl_in_gld.bfe_flag = 1 @@ -938,6 +945,8 @@ def emit_kernel_prologue(self): tc_index_dispatcher = igemm_thread_cluster_index_dispatcher_t(self.mc) tc_index_accumulator = igemm_thread_cluster_index_accumulator_t(self.mc) + nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 + nk_per_thread = tb_k0 if tb_k0 != 1 else tb_k1 if IGEMM_GTC_FEAT_MAGIC_DIVISION: m_mdiv_u32_vs = macro_mdiv_u32_rem_vs_t(self.mc) @@ -1146,7 +1155,6 @@ def emit_kernel_prologue(self): # voffset if ta_nb0 != 1 or ta_nb1 != 1: thread_stride = na_nb1 if ta_nb0 != 1 else 1 - nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 for i in range(1, nb_per_thread): self._emit(f"s_mov_b32 s1, {thread_stride * i}") @@ -1217,10 +1225,24 @@ def emit_kernel_prologue(self): self._emit(f"s_add_u32 s[{s.s_p_wei()}], s[{s.s_p_wei()}], s[{s.s_tmp()}]") self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") - self._emit(f"v_add_u32 v[{v.v_tmp(1)}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(1)}]") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") + 
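The v_wei_flag argument now passed to the weight load above is consumed through the use_flag path added to the global-load macros: each buffer_load is bracketed by an exec-mask update so that lanes whose flag is zero simply skip the access. A sketch of the per-element pattern the macro emits (register names and the exact buffer_load addressing are placeholders; the three-instruction shape is what matters):

    def guarded_buffer_load(v_dst, v_os, s_ptr, s_off, v_flag):
        # predicate one dword load on a per-lane flag, then restore all lanes
        return [
            f"v_cmpx_le_u32 vcc, 1, v[{v_flag}]",
            f"buffer_load_dword v[{v_dst}], v[{v_os}], s[{s_ptr}:{s_ptr}+3], s[{s_off}] offen offset:0",
            f"s_mov_b64 exec, -1",
        ]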
self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(5)}]") self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(data_byte)}") + # wei flag + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cndmask_b32 v[{v.v_wei_flag()}], 0, 1, vcc") + self._emit(f"v_mov_b32 v[{v.v_b()}], v[{v.v_wei_flag()}]") + + for i in range(1, nk_per_thread): + if i == 1: + k_thread_stride = nb_k1 if tb_k0 != 1 else 1 + self._emit(f"s_mov_b32 s[{s.s_tmp()}], {k_thread_stride}") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_tmp()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cndmask_b32 v[{v.v_wei_flag(i)}], 0, 1, vcc") + self._emit(f"v_lshl_or_b32 v[{v.v_b()}], v[{v.v_wei_flag(i)}], {i}, v[{v.v_b()}]") + self._emit_empty_line() if self.wei_thread_copy_ndim != 1: if s_wei_stride_d0 != s_dummy: @@ -1232,6 +1254,8 @@ def emit_kernel_prologue(self): if self.tunable.precache_soffset: self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) + # for i in range(nk_per_thread): + # self._emit(f"v_bfe_u32 v[{v.v_wei_flag(i)}], v[{v.v_b()}], {i}, 1") self._emit(self.global_load_wei()) self._emit_empty_line() @@ -1312,11 +1336,18 @@ def emit_kernel_prologue(self): self._emit(f"; move slice stride") self._emit(f"s_lshl_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_c()}], {igemm_log2(data_byte)}") + w_flag_cnt = 0 + self._emit(f"v_bfe_u32 v[{v.v_wei_flag(0)}], v[{v.v_b()}], {0}, 1") + w_flag_cnt = w_flag_cnt + 1 + # if self.tunable.nxe != 0: # self._emit(f"s_mov_b32 s[{s.s_tmp()}], {na_c}") # self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_tmp()}], {igemm_log2(data_byte)}") # else: self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_c()}], {na_c * data_byte}") + if w_flag_cnt < nk_per_thread: + self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_b()}], {w_flag_cnt}, 1") + w_flag_cnt = w_flag_cnt + 1 if self.tunable.nxe != 0: # s_in_diff_wi : s_dilation_w * s_in_stride_wi @@ -1335,7 +1366,12 @@ def emit_kernel_prologue(self): self._emit_empty_line() self._emit(f"s_mov_b32 s[{s.s_p_out(2)}], 0xffffffff") + if w_flag_cnt < nk_per_thread: + self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_b()}], {w_flag_cnt}, 1") + w_flag_cnt = w_flag_cnt + 1 self._emit(f"s_mov_b32 s[{s.s_p_out(3)}], 0x27000") + for i_w in range(w_flag_cnt, nk_per_thread): + self._emit(f"v_bfe_u32 v[{v.v_wei_flag(i_w)}], v[{v.v_b()}], {i_w}, 1") def emit_kernel_fma_main_loop(self): s = self.sgpr From 24d72e7986698ea07f22661af7b141a28302a1a4 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 19 Feb 2021 19:26:15 +0800 Subject: [PATCH 21/40] update nhwc --- igemm/algo/igemm_fwd_gtc_nhwc.py | 34 ++++++++++++++++++++------------ igemm/algo/xdlops_mapping.py | 7 +++++++ igemm/codegen/mc.py | 24 +++++++++++++++++++++- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index 35010592..1057b04c 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -313,14 +313,15 @@ def __call__(self): m_wei_2d_global_load, m_in_2d_global_load = self.outer.get_macro_global_load() with self._deferred_context(): self._emit(f"; load input, nxe:{self.outer.tunable.nxe}") - if self.outer.tunable.nxe != 0: + #if self.outer.tunable.nxe != 0: + if True: self._emit(f".v_clear_nc {v.v_gld_a()}, {m_in_2d_global_load.ctrl.length_d0 * 
m_in_2d_global_load.ctrl.length_d1}") if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), s.s_in_offset(), v.v_in_os(), v.v_in_flag(), v.v_tmp())) else: self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), s.s_in_offset(), v.v_in_os(), v.v_in_flag())) - else: - self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), s.s_in_offset(), v.v_in_os(), None)) + # else: + # self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), s.s_in_offset(), v.v_in_os(), None)) return self._get_deferred() @@ -534,7 +535,7 @@ def __init__(self, mc, outer): v_c_coalescing_num = outer.tunable.num_agpr_accumulate_c // outer.coalescing_store_groups v_c_needed = (v_c_coalescing_num - v_c_resuable_num) if (v_c_coalescing_num - v_c_resuable_num) > 0 else 0 - v_c_needed = v_c_needed if v_c_needed > 2 else 2 # let at least 2 + v_c_needed = v_c_needed if v_c_needed > 0 else 0 # let at least 0 self.v_c = sym_t("v_c" ,vseq(v_c_needed), f"coalescing:{v_c_coalescing_num}, needed:{v_c_needed}, resuable:{v_c_resuable_num}") self.v_a = sym_t("v_a" ,vseq(outer.tunable.num_vgpr_accumulate_a)) @@ -566,8 +567,8 @@ def __init__(self, mc, outer): self.v_co_sst = sym_t("v_co_sst" ,self.v_in_in.value) self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) - if outer.tunable.nxe != 0: - self.v_out_flag = sym_t("v_out_flag" ,self.v_wei_ik.value) + #if outer.tunable.nxe != 0: + self.v_out_flag = sym_t("v_out_flag" ,self.v_wei_ik.value) self.v_out_inb = sym_t("v_out_inb" ,self.v_in_inb.value) self.v_gemm_in = sym_t("v_gemm_in" ,vseq(1)) @@ -697,7 +698,8 @@ def get_macro_global_load(self): ctrl_wei_gld.length_d0 = 1 ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] else: - assert False + ctrl_wei_gld.length_d0 = 1 + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[-1] if self.in_thread_copy_ndim == 2: ctrl_in_gld.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] @@ -706,13 +708,15 @@ def get_macro_global_load(self): ctrl_in_gld.length_d0 = 1 ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] else: - assert False + ctrl_in_gld.length_d0 = 1 + ctrl_in_gld.length_d1 = in_thread_copy_dims[-1] ctrl_in_gld.use_flag = 1 ctrl_wei_gld.use_flag = 1 if self.tunable.nxe != 0: if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + ctrl_wei_gld.bfe_flag = 1 ctrl_in_gld.bfe_flag = 1 if self.tunable.precache_soffset: @@ -802,7 +806,8 @@ def get_symbol_global_load_s_stride_d0_d1(self): s_in_stride_d0 = s_dummy s_in_stride_d1 = in_stride_gprs[in_thread_copy_index[0]] else: - assert False + s_in_stride_d0 = s_dummy + s_in_stride_d1 = in_stride_gprs[-1] if self.wei_thread_copy_ndim == 2: # print(f" ____ wei_thread_copy_index:{len(wei_thread_copy_index)}, {wei_thread_copy_index}") @@ -812,7 +817,8 @@ def get_symbol_global_load_s_stride_d0_d1(self): s_wei_stride_d0 = s_dummy s_wei_stride_d1 = wei_stride_gprs[wei_thread_copy_index[0]] else: - assert False + s_wei_stride_d0 = s_dummy + s_wei_stride_d1 = wei_stride_gprs[-1] return s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 @@ -1142,7 +1148,8 @@ def emit_kernel_prologue(self): self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(0)}], v[{v.v_tmp()}]") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") - if self.tunable.nxe != 0: + #if self.tunable.nxe != 0: + if True: if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag()}], 16, 1") self._emit(m_set_flag_nhw(v.v_tmp(), v.v_tmp(1), 
v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) @@ -1306,7 +1313,8 @@ def emit_kernel_prologue(self): self._emit(self.coalescing_store.init_co_sub_n_index(v.v_co_sub_n_index(), '0', v.v_tmp())) self._emit_empty_line() - if self.tunable.nxe != 0: + #if self.tunable.nxe != 0: + if True: self._emit(f"v_add_u32 v[{v.v_tmp()}], s[{s.s_block_gtc_ik()}], v[{v.v_co_sub_n_index()}]") self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp()}]") self._emit(f"v_cndmask_b32 v[{v.v_out_flag()}], 0, 1, vcc") @@ -1567,7 +1575,7 @@ def emit_kernel_epilogue(self): a = self.agpr self._emit(self.coalescing_store(a.a_c(), v.v_c(), v.v_co_sst(), v.v_co_sld(), s.s_p_out(), v.v_out_os(), None, None, s.s_out_stride_wo(), - s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else None, s.s_dim_mr(), v.v_out_inb(), s.s_block_gtc_inb(), v.v_co_sub_m_index(), v.v_tmp())) + s.s_tmp(), v.v_out_flag() if self.tunable.nxe != 0 else v.v_out_flag(), s.s_dim_mr(), v.v_out_inb(), s.s_block_gtc_inb(), v.v_co_sub_m_index(), v.v_tmp())) self._emit_front(f"{self.label_out}:") diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py index a9f7a5a7..b9f44711 100755 --- a/igemm/algo/xdlops_mapping.py +++ b/igemm/algo/xdlops_mapping.py @@ -282,15 +282,21 @@ def serialize(self): ctrl_xdlops_mapping_t( 128, 128, 32, 64, 1, 4, 1, 1, 2, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 128, 64 , 32, 8 , 1, 4, 2, 2, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 4, 2, 1, 1, 1, v_mfma_f32_32x32x2f32), + ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 4, 1, 2, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 64 , 128, 8 , 32, 1, 4, 2, 2, 2, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 128, 32, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 64 , 128, 32, 32, 2, 4, 1, 2, 1, 1, v_mfma_f32_32x32x2f32), + ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 2, 2, 2, 1, 1, v_mfma_f32_32x32x2f32), + ctrl_xdlops_mapping_t( 64 , 128, 32, 32, 2, 2, 2, 2, 1, 1, v_mfma_f32_32x32x2f32), + ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 1, 2, 2, 2, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 128, 32 , 16, 16, 4, 4, 2, 2, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 128, 32 , 32, 32, 2, 2, 2, 1, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 32 , 128, 16, 16, 4, 4, 2, 2, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 32 , 128, 32, 32, 2, 2, 1, 2, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 64 , 16, 16, 4, 4, 2, 2, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 64 , 64 , 32, 32, 2, 4, 1, 1, 1, 1, v_mfma_f32_32x32x2f32), # this is not as good as 16x16x4 @@ -311,6 +317,7 @@ def serialize(self): ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 16 , 16, 16, 4, 4, 1, 1, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 64 , 16 , 16, 16, 4, 2, 2, 1, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 16 , 64, 16, 1, 1, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), ctrl_xdlops_mapping_t( 16 , 64 , 4 , 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 16 , 64 , 16, 
16, 4, 4, 1, 1, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 16 , 64 , 16, 16, 4, 2, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), diff --git a/igemm/codegen/mc.py b/igemm/codegen/mc.py index fcd9b36c..8fd6c396 100644 --- a/igemm/codegen/mc.py +++ b/igemm/codegen/mc.py @@ -29,6 +29,11 @@ from copy import deepcopy import subprocess +# NOTE: if following set to True, better parse '-V 0' to conv_driver +# since result can never be correct +MC_DEBUG_IGNORE_LDS_IO = False +MC_DEBUG_IGNORE_GLOBAL_IO = False + class mc_get_version_t(object): def __init__(self): self.called = 0 @@ -147,7 +152,24 @@ def __del__(self): self.close() def emit(self, s): if self.f: - self.f.write(self.indent() + s + '\n') + if MC_DEBUG_IGNORE_LDS_IO or MC_DEBUG_IGNORE_GLOBAL_IO: + s2 = s.split('\n') + ignore_list = list() + if MC_DEBUG_IGNORE_LDS_IO: + ignore_list.extend(['ds_read', 'ds_write', 's_barrier']) + # ignore_list.extend(['ds_write']) + if MC_DEBUG_IGNORE_GLOBAL_IO: + ignore_list.extend(['buffer_load', 's_waitcnt vmcnt']) + for iss, ss in enumerate(s2): + need_emit = True + for i in ignore_list: + if ss.strip().startswith(i): + need_emit = False + break + if need_emit: + self.f.write((self.indent() if iss == 0 else '') + ss + '\n') + else: + self.f.write(self.indent() + s + '\n') def emit_license(self): ''' From 1c62a04284caeaa34aa18a75462a02c9edcee047 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 22 Feb 2021 11:22:21 +0800 Subject: [PATCH 22/40] refactor some code --- igemm/algo/igemm_base.py | 16 ++ igemm/algo/igemm_fwd_gtc_nhwc.py | 251 ++++++++++++++++--------------- igemm/codegen/mbb.py | 11 ++ 3 files changed, 153 insertions(+), 125 deletions(-) diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index 0e3e166f..0cdebd0a 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -172,6 +172,8 @@ def __init__(self, tunable_dict): else: assert False + self.tensor_a_pass_through = utility_dict_with_default_t(tunable_dict)('tensor_a_pass_through', 0) + self.tensor_b_pass_through = utility_dict_with_default_t(tunable_dict)('tensor_b_pass_through', 0) self.tensor_a_thread_lengths = tunable_dict['tensor_a_thread_lengths'] # list! self.tensor_a_cluster_lengths = tunable_dict['tensor_a_cluster_lengths'] # list! self.tensor_b_thread_lengths = tunable_dict['tensor_b_thread_lengths'] # list! 
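Each ctrl_xdlops_mapping_t entry added above has to tile exactly the macro tile it advertises. Assuming the field layout (macro_m, macro_n, wave_tile_m, wave_tile_n, wave_tile_k, waves, f_m0, f_n0, f_m1, f_n1, mfma), where f_m0*f_m1 and f_n0*f_n1 are the per-wave repeat-times-step factors along m and n (only their products matter for this check), a quick consistency test is:

    def covers_macro_tile(entry):
        # does wave_tile * per-wave factors * wave placement reproduce the macro tile?
        macro_m, macro_n, wt_m, wt_n, _k, waves, f_m0, f_n0, f_m1, f_n1 = entry[:10]
        per_wave_m = wt_m * f_m0 * f_m1
        per_wave_n = wt_n * f_n0 * f_n1
        if macro_m % per_wave_m or macro_n % per_wave_n:
            return False
        return (macro_m // per_wave_m) * (macro_n // per_wave_n) == waves

    # e.g. the new 128x64 entry with a 32x32 wave tile and 2 waves
    assert covers_macro_tile((128, 64, 32, 32, 2, 2, 2, 2, 1, 1))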
@@ -356,6 +358,8 @@ def to_dict(self): tunable_dict['wave_tile_k'] = self.wave_tile_k else: assert False + tunable_dict['tensor_a_pass_through'] = self.tensor_a_pass_through + tunable_dict['tensor_b_pass_through'] = self.tensor_b_pass_through tunable_dict['tensor_a_thread_lengths'] = self.tensor_a_thread_lengths tunable_dict['tensor_a_cluster_lengths'] = self.tensor_a_cluster_lengths tunable_dict['tensor_b_thread_lengths'] = self.tensor_b_thread_lengths @@ -417,6 +421,12 @@ def get_dict_with_default(some_dict, key, default_value): line_start + 'wave_step_n {} {}'.format(equal, self.wave_step_n) + new_line + \ line_start + 'wave_repeat_n {} {}'.format(equal, self.wave_repeat_n) + new_line + \ line_start + 'wave_tile_k {} {}'.format(equal, self.wave_tile_k) + new_line + if self.tensor_a_pass_through: + sstr += \ + line_start + 'tensor_a_pass_through {} {}'.format(equal, self.tensor_a_pass_through) + new_line + if self.tensor_b_pass_through: + sstr += \ + line_start + 'tensor_b_pass_through {} {}'.format(equal, self.tensor_b_pass_through) + new_line sstr += \ line_start + 'tensor_a_thread_lengths {} {}'.format(equal, self.tensor_a_thread_lengths) + new_line + \ line_start + 'tensor_a_cluster_lengths {} {}'.format(equal, self.tensor_a_cluster_lengths) + new_line + \ @@ -476,6 +486,12 @@ def lengths_str(lengths): kernel_name += "ta" + lengths_str(tunable.tensor_a_thread_lengths) + "_" + lengths_str(tunable.tensor_a_cluster_lengths) + "_" +\ "tb" + lengths_str(tunable.tensor_b_thread_lengths) + "_" + lengths_str(tunable.tensor_b_cluster_lengths) + if tunable.tensor_a_pass_through: + kernel_name += "_pta" + + if tunable.tensor_b_pass_through: + kernel_name += "_ptb" + if tunable.gemm_m_unmerge_cluster: kernel_name += "_mc" diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index 1057b04c..f7f3e1a2 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -116,7 +116,7 @@ def flatten(x): ctrl_coalescing_store_xdlops.block_size = self.tunable.block_size # gemm_m_order, gemm_n_order = self.get_lds_gemm_m_gemm_n_order() - na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() ctrl_coalescing_store_xdlops.gemm_m_m0_m1 = [na_nb0, na_nb1] #if gemm_m_order == IGEMM_FWD_GTC_NHWC_LDS_STORE_ORDER_GEMM_M_N1B_N0: # # we may consider not suppor this mode @@ -425,7 +425,7 @@ def emit(self): class kernel_sgpr_t(mc_base_t): def __init__(self, mc, outer): mc_base_t.__init__(self, mc) - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = outer.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = outer.get_thread_lengths() sseq = gpr_sequencer_t() self.outer = outer self.s_ka = sym_t('s_ka' , sseq(2)) @@ -516,8 +516,8 @@ class kernel_vgpr_t(mc_base_t): def __init__(self, mc, outer): mc_base_t.__init__(self, mc) self.outer = outer - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = outer.get_thread_lengths() - ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = outer.get_cluster_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = outer.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = outer.get_cluster_lengths() nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 nk_per_thread = tb_k0 if tb_k0 != 1 else tb_k1 @@ -567,7 +567,6 @@ def __init__(self, mc, outer): self.v_co_sst = sym_t("v_co_sst" ,self.v_in_in.value) self.v_co_sld = sym_t("v_co_sld" ,vseq(1)) - #if outer.tunable.nxe != 0: self.v_out_flag = sym_t("v_out_flag" 
,self.v_wei_ik.value) self.v_out_inb = sym_t("v_out_inb" ,self.v_in_inb.value) @@ -622,16 +621,21 @@ def get_thread_lengths(self): ta_e, ta_c, ta_nb0, ta_nb1 = t_ta[0], t_ta[1], t_ta[2], t_ta[3] tb_e, tb_c, tb_k0, tb_k1 = t_tb[0], t_tb[1], t_tb[2], t_tb[3] - assert ta_e == tb_e and ta_c == tb_c - assert ta_c in (1, 2, 4), "currently c will be used as LDS store/load vector size, now only support this" + if self.tunable.tensor_a_pass_through or self.tunable.tensor_b_pass_through: + pass + else: + assert ta_e == tb_e and ta_c == tb_c + assert ta_c in (1, 2, 4), "currently c will be used as LDS store/load vector size, now only support this" assert ta_e == 1, "currently not support >1 in e dimension" # it's no point to have both x0, x1 have copy value - assert not (ta_nb0 != 1 and ta_nb1 != 1) - assert not (tb_k0 != 1 and tb_k1 != 1) + if not self.tunable.tensor_a_pass_through: + assert not (ta_nb0 != 1 and ta_nb1 != 1) + if not self.tunable.tensor_b_pass_through: + assert not (tb_k0 != 1 and tb_k1 != 1) - return ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 # M, K, N + return ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 # M, K, N def get_cluster_lengths(self): c_ta = self.tunable.tensor_a_cluster_lengths @@ -642,24 +646,28 @@ def get_cluster_lengths(self): ca_e, ca_c, ca_nb0, ca_nb1 = c_ta[0], c_ta[1], c_ta[2], c_ta[3] cb_e, cb_c, cb_k0, cb_k1 = c_tb[0], c_tb[1], c_tb[2], c_tb[3] - assert ca_nb1 != 1 - assert ca_e == cb_e and ca_c == cb_c + if not self.tunable.tensor_a_pass_through: + assert ca_nb1 != 1 + assert ca_e == cb_e and ca_c == cb_c + assert ca_nb0 == 1 + if not self.tunable.tensor_b_pass_through: + assert cb_k0 == 1 - assert ca_e == 1 and ca_nb0 == 1 and cb_k0 == 1 + assert ca_e == 1 - return ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 # M, K, N + return ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 # M, K, N def get_dims_lengths(self): - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() - ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = self.get_cluster_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = self.get_cluster_lengths() - na_nb0, na_nb1, na_e, na_c = ta_nb0 * ca_nb0, ta_nb1 * ca_nb1, ta_e * ca_e, ta_c * ca_c - nb_k0, nb_k1 = tb_k0 * cb_k0, tb_k1 * cb_k1 + na_nb0, na_nb1, na_e, na_c = ta_nb0 * ca_nb0, ta_nb1 * ca_nb1, ta_e * ca_e, ta_c * ca_c + nb_k0, nb_k1 , nb_e, nb_c = tb_k0 * cb_k0, tb_k1 * cb_k1, tb_e * cb_e, tb_c * cb_c - return na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 # M, K, N + return na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 # M, K, N def get_thread_copy_dims(self): - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() in_thread_copy_dims = [ta_nb0, ta_nb1, ta_e, ta_c] wei_thread_copy_dims = [tb_k0, tb_k1, ta_e, ta_c] # always reordered! return in_thread_copy_dims, wei_thread_copy_dims @@ -680,8 +688,8 @@ def get_macro_global_load(self): NOTICE: input/wei always load gemm_k (e*c) first. 
indeed always load c, and do vector load if possible ''' inline = True if self.tunable.fma_interleave else False - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() - na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() @@ -728,8 +736,8 @@ def get_macro_global_load(self): def get_macro_shared_store(self): #in_thread_copy_dims, wei_thread_copy_dims = self.get_thread_copy_dims() #in_thread_copy_index, wei_thread_copy_index = self.get_thread_copy_index() - na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() data_byte = amdgpu_precision_data_byte(self.tunable.precision) k_pack = ta_c # always use this as k_pack @@ -766,7 +774,7 @@ def get_macro_move_slice_window(self): def get_macro_move_slice_window_accumulate(self): inline = True if self.tunable.fma_interleave else False if self.tunable.nxe != 0: - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() nb_per_thread = ta_nb0 if ta_nb0 != 1 else ta_nb1 return self.macro_move_slice_window_block_wise_acc_yx_t(self.mc, self.tunable, inline, label_acc_yx = self.name() + "_acc_yx", @@ -780,7 +788,7 @@ def get_macro_set_flag_nhw(self): return self.macro_set_flag_nhw(self.mc, inline) def get_symbol_global_load_s_stride_d0_d1(self): - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() # get the symbol object that load 2d may use s = self.sgpr s_dummy = sym_t("s_dummy") @@ -937,9 +945,9 @@ def emit_kernel_prologue(self): v = self.vgpr k = self.karg - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() - ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = self.get_cluster_lengths() - na_nb0, na_nb1, na_e, na_c, nb_k0, nb_k1 = self.get_dims_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = self.get_cluster_lengths() + na_nb0, na_nb1, na_e, na_c, nb_e, nb_c, nb_k0, nb_k1 = self.get_dims_lengths() data_byte = amdgpu_precision_data_byte(self.tunable.precision) @@ -1091,16 +1099,16 @@ def emit_kernel_prologue(self): # transform nb self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_inb()}], v[{v.v_in_inb()}]") - if self.tunable.nxe != 0: - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wo(), v.v_tmp())) - else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) - self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), 
v.v_in_ihi_list(0), v.v_tmp(4), s.s_wo(), v.v_tmp(), s.s_tmp())) + if IGEMM_GTC_FEAT_MAGIC_DIVISION: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") + self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wo() if self.tunable.nxe != 0 else s.s_wi(), v.v_tmp())) + else: + self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) + self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_wo() if self.tunable.nxe != 0 else s.s_wi(), v.v_tmp(), s.s_tmp())) + if self.tunable.nxe != 0: # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w self._emit(f"v_mul_lo_u32 v[{v.v_in_ihi_list(0)}], s[{s.s_stride_h()}], v[{v.v_in_ihi_list(0)}]") @@ -1109,16 +1117,6 @@ def emit_kernel_prologue(self): self._emit(f"v_sub_i32 v[{v.v_in_iwi_list(0)}], v[{v.v_in_iwi_list(0)}], s[{s.s_pad_w()}]") self._emit_empty_line() - else: - if IGEMM_GTC_FEAT_MAGIC_DIVISION: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_magic_1(), s.s_tmp(3), s.s_dim_br(), v.v_tmp())) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080010 ; offset:16, width:8") - self._emit(m_mdiv_u32_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_magic_2(), s.s_tmp(3), s.s_wi(), v.v_tmp())) - else: - self._emit(m_int_div_rem_vs(v.v_tmp(4), v.v_in_in(), v.v_tmp(5), s.s_dim_br(), v.v_tmp(), s.s_tmp())) - self._emit(m_int_div_rem_vs(v.v_in_iwi_list(0), v.v_in_ihi_list(0), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) - if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: # update flag for batch size self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_n()}], v[{v.v_in_in()}]") @@ -1130,37 +1128,38 @@ def emit_kernel_prologue(self): self._emit(f"v_lshlrev_b32 v[{v.v_in_flag_n()}], 0, v[{v.v_tmp()}]") self._emit(f"s_lshl_b32 s[{s.s_block_gtc_ig()}], s[{s.s_block_gtc_ig()}], {igemm_log2(data_byte)}") - self._emit(f"; calculate in offset") - self._emit(f"s_mov_b32 s[{s.s_in_offset()}], 0") - # compute group distance - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") - self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") - self._emit(f"s_add_u32 s[{s.s_p_in(0)}], s[{s.s_p_in(0)}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") - self._emit_empty_line() - self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") - # s_in_stride_wi need shift before! 
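The division sequence above is recovering the per-thread output coordinate from the flattened gemm_m index and, for nxe != 0, mapping it back into the padded input. In plain integer arithmetic the emitted magic-division / div-rem code computes the following, assuming s_dim_br is ho*wo (hi*wi when nxe is 0):

    def decompose_gemm_m(inb, ho, wo):
        # v_in_inb is the flattened n*ho*wo coordinate owned by this thread
        n,   rem = divmod(inb, ho * wo)     # divide by s_dim_br
        iho, iwo = divmod(rem, wo)          # divide by s_wo
        return n, iho, iwo

    def output_to_input(iho, iwo, stride_h, stride_w, dilation_h, dilation_w,
                        pad_h, pad_w, iy=0, ix=0):
        # nxe != 0 path, starting at filter tap (0, 0); the move-slice-window
        # code advances iy/ix later, and the nhw flag masks negative or
        # out-of-range coordinates
        ihi = iho * stride_h + iy * dilation_h - pad_h
        iwi = iwo * stride_w + ix * dilation_w - pad_w
        return ihi, iwi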
- self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) + def calculate_and_load_input(): + self._emit(f"; calculate in offset") + self._emit(f"s_mov_b32 s[{s.s_in_offset()}], 0") + # compute group distance + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") + self._emit(f"s_add_u32 s[{s.s_p_in(0)}], s[{s.s_p_in(0)}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_in(1)}], s[{s.s_p_in(1)}], s[{s.s_tmp(1)}]") + self._emit_empty_line() - self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(0)}]") - self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(0)}], v[{v.v_tmp()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") - self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") - #if self.tunable.nxe != 0: - if True: - if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: - self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag()}], 16, 1") - self._emit(m_set_flag_nhw(v.v_tmp(), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) - self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], 0, v[{v.v_in_flag()}]") - else: - self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag_n()}], 0, 1") - self._emit(m_set_flag_nhw(v.v_in_flag(0), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) - self._emit_empty_line() + self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") + # s_in_stride_wi need shift before! + self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) - # voffset - if ta_nb0 != 1 or ta_nb1 != 1: + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(0)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(0)}], v[{v.v_tmp()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") + self._emit(f"v_add_u32 v[{v.v_in_os()}], v[{v.v_tmp(4)}], v[{v.v_tmp()}]") + + if True: + if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: + self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag()}], 16, 1") + self._emit(m_set_flag_nhw(v.v_tmp(), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) + self._emit(f"v_lshl_or_b32 v[{v.v_in_flag()}], v[{v.v_tmp()}], 0, v[{v.v_in_flag()}]") + else: + self._emit(f"v_bfe_u32 v[{v.v_tmp(1)}], v[{v.v_in_flag_n()}], 0, 1") + self._emit(m_set_flag_nhw(v.v_in_flag(0), v.v_tmp(1), v.v_in_ihi_list(0), v.v_in_iwi_list(0), s.s_hi(), s.s_wi())) + self._emit_empty_line() + + # voffset, for [1, nb_per_thread) pixels thread_stride = na_nb1 if ta_nb0 != 1 else 1 for i in range(1, nb_per_thread): @@ -1214,57 +1213,59 @@ def emit_kernel_prologue(self): self._emit(f"v_cndmask_b32 v[{v.v_tmp()}], 0, 1, vcc") self._emit(f"v_lshl_or_b32 v[{v.v_in_flag_n()}], v[{v.v_tmp()}], {i}, v[{v.v_in_flag_n()}]") self._emit(m_set_flag_nhw(v.v_in_flag(i), v.v_tmp(), v.v_in_ihi_list(i), v.v_in_iwi_list(i), s.s_hi(), s.s_wi())) - else: - pass - # load in - self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") - self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") - self._emit(self.global_load_in()) - self._emit_empty_line() - self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") - self._emit(f"s_mov_b32 
s[{s.s_p_wei(3)}], 0x27000") - - self._emit(f"; calculate wei offset") - self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_k()}], s[{s.s_wei_stride_k()}]") - self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") - self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") - self._emit(f"s_add_u32 s[{s.s_p_wei()}], s[{s.s_p_wei()}], s[{s.s_tmp()}]") - self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") - - self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") - self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(5)}]") - self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(data_byte)}") - - # wei flag - self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") - self._emit(f"v_cndmask_b32 v[{v.v_wei_flag()}], 0, 1, vcc") - self._emit(f"v_mov_b32 v[{v.v_b()}], v[{v.v_wei_flag()}]") - - for i in range(1, nk_per_thread): - if i == 1: - k_thread_stride = nb_k1 if tb_k0 != 1 else 1 - self._emit(f"s_mov_b32 s[{s.s_tmp()}], {k_thread_stride}") - self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_tmp()}], v[{v.v_tmp(5)}]") + # load in + self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") + self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") + self._emit(self.global_load_in()) + self._emit_empty_line() + + def calculate_and_load_weight(): + self._emit(f"; calculate wei offset") + self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_k()}], s[{s.s_wei_stride_k()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") + self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_tmp(2)}]") + self._emit(f"s_add_u32 s[{s.s_p_wei()}], s[{s.s_p_wei()}], s[{s.s_tmp()}]") + self._emit(f"s_addc_u32 s[{s.s_p_wei(1)}], s[{s.s_p_wei(1)}], s[{s.s_tmp(1)}]") + + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") + self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(data_byte)}") + + # wei flag self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") - self._emit(f"v_cndmask_b32 v[{v.v_wei_flag(i)}], 0, 1, vcc") - self._emit(f"v_lshl_or_b32 v[{v.v_b()}], v[{v.v_wei_flag(i)}], {i}, v[{v.v_b()}]") + self._emit(f"v_cndmask_b32 v[{v.v_wei_flag()}], 0, 1, vcc") + self._emit(f"v_mov_b32 v[{v.v_b()}], v[{v.v_wei_flag()}]") + + for i in range(1, nk_per_thread): + if i == 1: + k_thread_stride = nb_k1 if tb_k0 != 1 else 1 + self._emit(f"s_mov_b32 s[{s.s_tmp()}], {k_thread_stride}") + self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_tmp()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_cndmask_b32 v[{v.v_wei_flag(i)}], 0, 1, vcc") + self._emit(f"v_lshl_or_b32 v[{v.v_b()}], v[{v.v_wei_flag(i)}], {i}, v[{v.v_b()}]") - self._emit_empty_line() - if self.wei_thread_copy_ndim != 1: - if s_wei_stride_d0 != s_dummy: - self._emit(self.try_shift_stride(s_wei_stride_d0, igemm_log2(data_byte))) - if s_wei_stride_d1 != s_dummy: - self._emit(self.try_shift_stride(s_wei_stride_d1, igemm_log2(data_byte))) - self._emit_empty_line() + self._emit_empty_line() + if self.wei_thread_copy_ndim != 1: + if s_wei_stride_d0 != s_dummy: + self._emit(self.try_shift_stride(s_wei_stride_d0, igemm_log2(data_byte))) + if s_wei_stride_d1 != s_dummy: + self._emit(self.try_shift_stride(s_wei_stride_d1, igemm_log2(data_byte))) 
+ self._emit_empty_line() - if self.tunable.precache_soffset: - self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) + if self.tunable.precache_soffset: + self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) + + self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") + self._emit(f"s_mov_b32 s[{s.s_p_wei(3)}], 0x27000") + self._emit(self.global_load_wei()) + self._emit_empty_line() + + # do load + calculate_and_load_input() + calculate_and_load_weight() - # for i in range(nk_per_thread): - # self._emit(f"v_bfe_u32 v[{v.v_wei_flag(i)}], v[{v.v_b()}], {i}, 1") - self._emit(self.global_load_wei()) - self._emit_empty_line() if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") @@ -1520,7 +1521,7 @@ def move_slice_window_acc(): fctrl.shared_store_a_functor = self.shared_store_wei fctrl.shared_store_b_functor = self.shared_store_in - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() fctrl.lds_k_pack = ta_c if ctrl_xdlops_mapping.wave_step_m == 1: @@ -1560,8 +1561,8 @@ def emit_kernel_epilogue(self): v = self.vgpr #label_out = f"L_{self.name()}_out" - ta_nb0, ta_nb1, ta_e, ta_c, tb_k0, tb_k1 = self.get_thread_lengths() - ca_nb0, ca_nb1, ca_e, ca_c, cb_k0, cb_k1 = self.get_cluster_lengths() + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + ca_nb0, ca_nb1, ca_e, ca_c, cb_e, cb_c, cb_k0, cb_k1 = self.get_cluster_lengths() if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # if self.tunable.nxe != 0: diff --git a/igemm/codegen/mbb.py b/igemm/codegen/mbb.py index 2c778917..929434b3 100644 --- a/igemm/codegen/mbb.py +++ b/igemm/codegen/mbb.py @@ -133,6 +133,17 @@ def dump(self): print(self()) print('-----------------------------------') +def machine_basic_block_call(p, mbb): + ''' + to pretty print mbb, the indent + currently p can not be mc_base_t directly. must be some child class + ''' + mbb_lines = mbb().split('\n') + with p._deferred_context(): + for line in mbb_lines: + p._emit(line) + return p._get_deferred() + def create_machine_basic_block(multi_line_inst_str, **option): ''' an post analysis and construction of mbb, only based on string parse. 
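machine_basic_block_call simply re-emits an already-built basic block through an emitter object so the pretty printer's indentation is respected. An illustrative fragment of how it is meant to be used from inside an mc_base_t-derived generator (the asm string here is a made-up placeholder, and create_machine_basic_block is assumed to return a list of blocks):

    # split a multi-instruction string into basic blocks, then emit them
    asm  = "v_add_u32 v[0], v[1], v[2]\nv_mul_lo_u32 v[3], v[0], v[4]"
    mbbs = create_machine_basic_block(asm)
    for mbb in mbbs:
        self._emit(machine_basic_block_call(self, mbb))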
From 8edb35782fe768d36b2db4ef077b2bc1de954d77 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 26 Feb 2021 00:23:48 +0800 Subject: [PATCH 23/40] further update --- driver/igemm_gtc_base.h | 10 + igemm/algo/global_memory.py | 139 +++++++++++ igemm/algo/igemm_base.py | 28 ++- igemm/algo/igemm_fwd_gtc_nhwc.py | 399 ++++++++++++++++++++---------- igemm/algo/mfma_main_loop.py | 411 ++++++++++++++++++++++++++++++- igemm/algo/xdlops_mapping.py | 29 ++- igemm/codegen/mbb.py | 11 + igemm/codegen/scheduler.py | 4 +- 8 files changed, 882 insertions(+), 149 deletions(-) diff --git a/driver/igemm_gtc_base.h b/driver/igemm_gtc_base.h index e8215f01..c3b4bfac 100755 --- a/driver/igemm_gtc_base.h +++ b/driver/igemm_gtc_base.h @@ -123,6 +123,8 @@ typedef struct { int wave_tile_k; }; }; + int tensor_a_pass_through; + int tensor_b_pass_through; std::vector tensor_a_thread_lengths; std::vector tensor_a_cluster_lengths; std::vector tensor_b_thread_lengths; @@ -187,6 +189,8 @@ igemm_gtc_tunable_from_config(const config_content_t &content) { tunable.wave_repeat_n = sec.at("wave_repeat_n").get_int(); tunable.wave_tile_k = sec.count("wave_tile_k") > 0 ? sec.at("wave_tile_k").get_int() : 1; } + tunable.tensor_a_pass_through = sec.count("tensor_a_pass_through") > 0 ? sec.at("tensor_a_pass_through").get_int() : 0; + tunable.tensor_b_pass_through = sec.count("tensor_b_pass_through") > 0 ? sec.at("tensor_b_pass_through").get_int() : 0; tunable.tensor_a_thread_lengths = sec.at("tensor_a_thread_lengths").get_list_int(); tunable.tensor_a_cluster_lengths = sec.at("tensor_a_cluster_lengths").get_list_int(); tunable.tensor_b_thread_lengths = sec.at("tensor_b_thread_lengths").get_list_int(); @@ -222,6 +226,8 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) { // auto gemm_n_per_thread = tunable->gemm_n_per_thread; // auto gemm_n_level0_cluster = tunable->gemm_n_level0_cluster; // auto gemm_n_level1_cluster = tunable->gemm_n_level1_cluster; + auto tensor_a_pass_through = tunable->tensor_a_pass_through; + auto tensor_b_pass_through = tunable->tensor_b_pass_through; auto tensor_a_thread_lengths = tunable->tensor_a_thread_lengths; auto tensor_a_cluster_lengths = tunable->tensor_a_cluster_lengths; auto tensor_b_thread_lengths = tunable->tensor_b_thread_lengths; @@ -296,6 +302,10 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) { "tb" + utility_int_list_to_string(tensor_b_thread_lengths) + "_" + utility_int_list_to_string(tensor_b_cluster_lengths); // printf("[%s]\n",kernel_name.c_str()); + if(tensor_a_pass_through) + kernel_name += std::string("_pta"); + if(tensor_b_pass_through) + kernel_name += std::string("_ptb"); if(gemm_m_unmerge_cluster) kernel_name += std::string("_mc"); if(gemm_n_unmerge_cluster) diff --git a/igemm/algo/global_memory.py b/igemm/algo/global_memory.py index 92df4181..613f3a7b 100755 --- a/igemm/algo/global_memory.py +++ b/igemm/algo/global_memory.py @@ -109,6 +109,13 @@ def __init__(self): self.dst_order = 0 # 0-d0xd1, 1-d1xd0 self.use_flag = 0 self.bfe_flag = 0 + self.precache_vs_ptn = 0 # 0: d0 use sgpr precache, d1 use vgpr precache + # 1: d0 use vgpr precache, d1 use sgpr precache + # 2: d0 use vgpr precache, d1 use vgpr precache + # 3: d0 use sgpr precache, d1 use sgpr precache + # 4: .... maybe consider not using precache? 
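The precache_vs_ptn variants introduced below differ only in which of the two dimensions gets per-element vector offsets and which gets precomputed scalar offsets; on the scalar side both new classes pick the per-element offset the same way. A sketch of that selection, assuming the s_offset table holds i*stride as set up by init_precache_soffset:

    def scalar_offset_for(i, s_stride_d1, s_offset_table):
        # element 0 uses offset 0, element 1 reuses the stride register itself,
        # elements >= 2 come from the precached s_offset[] entries
        if i == 0:
            return 0
        if i == 1:
            return s_stride_d1
        return s_offset_table[i - 2]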
+ + class macro_igemm_2d_global_load_t(macro_base_t): # TODO: if need vectorize further LDS write, need shuffle dst gpr while load @@ -439,6 +446,138 @@ def get_issues(self): n_d1 = ctrl.length_d1 // ctrl.vector_d1 return ctrl.length_d0 * n_d1 +class macro_igemm_2d_global_load_precache_vs_offset_t(macro_base_t): + # precache voffset for d0 dimension + # precache soffset for d1 dimension + # hence v_flag is along d0 dimension + def __init__(self, mc, ctrl, inline = False): + assert type(ctrl) is ctrl_2d_global_load_t + macro_base_t.__init__(self, mc, inline) + self.ctrl = ctrl + self.declare_arg("v_dst") + self.declare_arg("s_ptr") + self.declare_arg("v_os") + self.declare_arg("s_stride_d0") # can be None + self.declare_arg("s_stride_d1") + self.declare_arg("s_offset") + if self.ctrl.use_flag: + self.declare_arg("v_flag") + if self.ctrl.bfe_flag: + self.declare_arg("v_tmp") + + def name(self): + ctrl = self.ctrl + if ctrl.precision == "fp32": + bits_str = 'b32' + elif ctrl.precision in ("fp16", "bf16"): + bits_str = 'b16' + else: + assert False + + if ctrl.vector_d1 == 4: + vec_str = 'v4' + elif ctrl.vector_d1 == 2: + vec_str = 'v2' + elif ctrl.vector_d1 == 1: + vec_str = 'v1' + else: + assert False + + return f".v_gld_{ctrl.length_d0}x{ctrl.length_d1}_{bits_str}_{vec_str}_precache_vs_offset" + + def expr(self): + ctrl = self.ctrl + assert ctrl.length_d1 % ctrl.vector_d1 == 0 + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + assert ctrl.precision == 'fp32', "TO BE supported" + buffer_load_dword = inst_buffer_load_dword_t(ctrl.vector_d1) + + if ctrl.src_order == 0 and ctrl.dst_order == 0: + i_dst = 0 + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(n_d1): + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_d0)}]") + current_s_offset = 0 if i_d1 == 0 else (self.s_stride_d1() if i_d1 == 1 else self.s_offset(i_d1 - 2)) + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d0)}", f"{self.s_ptr()}", current_s_offset, 0)) + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + i_dst = i_dst + 1 + + else: + assert False + + def get_issues(self): + ctrl = self.ctrl + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + return ctrl.length_d0 * n_d1 + +class macro_igemm_2d_global_load_precache_sv_offset_t(macro_base_t): + # precache soffset for d0 dimension + # precache voffset for d1 dimension + # hence v_flag is along d1 dimension + def __init__(self, mc, ctrl, inline = False): + assert type(ctrl) is ctrl_2d_global_load_t + macro_base_t.__init__(self, mc, inline) + self.ctrl = ctrl + self.declare_arg("v_dst") + self.declare_arg("s_ptr") + self.declare_arg("v_os") + self.declare_arg("s_stride_d0") # can be None + self.declare_arg("s_stride_d1") + self.declare_arg("s_offset") + if self.ctrl.use_flag: + self.declare_arg("v_flag") + if self.ctrl.bfe_flag: + self.declare_arg("v_tmp") + + def name(self): + ctrl = self.ctrl + if ctrl.precision == "fp32": + bits_str = 'b32' + elif ctrl.precision in ("fp16", "bf16"): + bits_str = 'b16' + else: + assert False + + if ctrl.vector_d1 == 4: + vec_str = 'v4' + elif ctrl.vector_d1 == 2: + vec_str = 'v2' + elif ctrl.vector_d1 == 1: + vec_str = 'v1' + else: + assert False + + return f".v_gld_{ctrl.length_d0}x{ctrl.length_d1}_{bits_str}_{vec_str}_precache_sv_offset" + + def expr(self): + ctrl = self.ctrl + assert ctrl.length_d1 % ctrl.vector_d1 == 0 + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + assert ctrl.precision == 'fp32', "TO BE supported" + buffer_load_dword = 
inst_buffer_load_dword_t(ctrl.vector_d1) + + if ctrl.src_order == 0 and ctrl.dst_order == 0: + i_dst = 0 + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(n_d1): + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_d1)}]") + current_s_offset = 0 if i_d0 == 0 else (self.s_stride_d1() if i_d0 == 1 else self.s_offset(i_d0 - 2)) + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d1)}", f"{self.s_ptr()}", current_s_offset, 0)) + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + i_dst = i_dst + 1 + + else: + assert False + + def get_issues(self): + ctrl = self.ctrl + n_d1 = ctrl.length_d1 // ctrl.vector_d1 + return ctrl.length_d0 * n_d1 + class macro_igemm_write_4d_strided_t(macro_base_t): ''' TODO: this is always not inline diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index 0cdebd0a..b9f48444 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -268,6 +268,8 @@ def _unmerge_x1_from_e(unroll_k, nxe): elif self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self.local_prefetch_num = 2 if IGEMM_GTC_FEAT_LOCAL_PREFETCH else 1 + if (self.tensor_a_pass_through and self.wave_repeat_n) or (self.tensor_b_pass_through and self.wave_repeat_m): + self.local_prefetch_num = 1 # register for a,b,c buffer xdlops_mapping = get_ctrl_xdlops_mapping_fp32(self.gemm_m_per_block, self.gemm_n_per_block, self.block_size // amdgpu_wave_size(tunable_dict['arch'])) self.num_agpr_accumulate_c = xdlops_mapping.total_acc_c() @@ -282,13 +284,13 @@ def _unmerge_x1_from_e(unroll_k, nxe): assert self.num_vgpr_global_load_b * self.block_size == self.gemm_n_per_block * self.gemm_k_per_block # LDS size - self.lds_a = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_m_per_block - self.lds_b = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_n_per_block - self.lds_a_np2 = igemm_next_pow2( self.lds_a) - self.lds_b_np2 = igemm_next_pow2( self.lds_b) - self.lds_single = igemm_next_pow2( self.lds_a_np2 + self.lds_b_np2) - self.lds_buffer_num = 1 if self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS else 2 - self.lds_total = self.lds_buffer_num * self.lds_single + self.lds_a = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_m_per_block if not self.tensor_a_pass_through else 0 + self.lds_b = amdgpu_precision_data_byte(self.precision) * self.gemm_k_per_block * self.gemm_n_per_block if not self.tensor_b_pass_through else 0 + self.lds_a_np2 = igemm_next_pow2( self.lds_a) if self.lds_a != 0 else 0 + self.lds_b_np2 = igemm_next_pow2( self.lds_b) if self.lds_b != 0 else 0 + self.lds_single = igemm_next_pow2( self.lds_a_np2 + self.lds_b_np2) if (self.lds_a_np2 + self.lds_b_np2 != 0) else 0 + self.lds_buffer_num = 1 if self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS else 2 + self.lds_total = self.lds_buffer_num * self.lds_single # print(f"lds_a:{self.lds_a}, lds_b:{self.lds_b}, lds_a_np2:{self.lds_a_np2}, lds_b_np2:{self.lds_b_np2}, lds_single:{self.lds_single}, lds_total:{self.lds_total}") # TODO: LDS size check @@ -300,12 +302,15 @@ def _unmerge_x1_from_e(unroll_k, nxe): self.thread_sub_tile_n = self.gemm_n_per_thread # number of loops at least needed for final coalescing store, dicided by LDS size - self.coalescing_store_groups = (self.gemm_m_per_block * self.gemm_n_per_block) // \ - (self.lds_buffer_num * igemm_next_pow2(igemm_next_pow2(self.gemm_k_per_block * self.gemm_m_per_block) + 
igemm_next_pow2(self.gemm_k_per_block * self.gemm_n_per_block) )) + # self.coalescing_store_groups = (self.gemm_m_per_block * self.gemm_n_per_block) // \ + # (self.lds_buffer_num * igemm_next_pow2(igemm_next_pow2(self.gemm_k_per_block * self.gemm_m_per_block) + igemm_next_pow2(self.gemm_k_per_block * self.gemm_n_per_block) )) + self.coalescing_store_groups = (self.gemm_m_per_block * self.gemm_n_per_block) // (self.lds_total // amdgpu_precision_data_byte(self.precision)) + if self.coalescing_store_groups == 0: self.coalescing_store_groups = 1 # this means LDS size is already bigger than c matrix all pixel. just use one group is ok #if self.coalescing_store_groups < 2: # self.coalescing_store_groups = 2 + shrinked_lds_buffer_num = self.lds_buffer_num if self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # check on grouping xdlops_mapping = get_ctrl_xdlops_mapping_fp32(self.gemm_m_per_block, self.gemm_n_per_block, self.block_size // amdgpu_wave_size(tunable_dict['arch'])) @@ -316,8 +321,8 @@ def _unmerge_x1_from_e(unroll_k, nxe): shrink_in_co_group = self.coalescing_store_groups // length_in_m # TODO: this may affect occupancy! - self.lds_buffer_num = self.lds_buffer_num * shrink_in_co_group - self.lds_total = self.lds_buffer_num * self.lds_single + shrinked_lds_buffer_num = shrinked_lds_buffer_num * shrink_in_co_group + self.lds_total = shrinked_lds_buffer_num * self.lds_single self.coalescing_store_groups = self.coalescing_store_groups // shrink_in_co_group def output(self): @@ -448,6 +453,7 @@ def get_dict_with_default(some_dict, key, default_value): line_start + 'thread_tile {} {}x{}'.format(equal, self.thread_tile_m, self.thread_tile_n) + new_line sstr += \ line_start + 'lds_total {} {}'.format(equal, self.lds_total) + new_line + \ + line_start + 'lds_buffer_num {} {}'.format(equal, self.lds_buffer_num) + new_line + \ line_start return sstr diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index f7f3e1a2..c8410272 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -166,16 +166,23 @@ class macro_move_slice_window_block_wise_1x1_t(macro_base_t): def __init__(self, mc, tunable, inline, **options): macro_base_t.__init__(self, mc, True) self.tunable = tunable - self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset + if tunable.tensor_a_pass_through: + self.declare_arg("s_in_base") # 64bit acc + else: + self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset self.declare_arg("v_wei_os") self.declare_arg("s_move_slice_k_stride_c") # this is indeed gemm_k * data_byte, same for input/weight self.options = options def name(self): - return '.v_fwd_gtc_nhwc_move_slice_window_block_wise_1x1' + return '.v_fwd_gtc_nhwc_move_slice_window_block_wise_1x1_{self.tunable.tensor_a_pass_through}_{self.tunable.tensor_b_pass_through}' def expr(self): - self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") + if self.tunable.tensor_a_pass_through: + self._emit(f"s_add_u32 s[{self.s_in_base()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_base()}]") + self._emit(f"s_addc_u32 s[{self.s_in_base(1)}], 0, s[{self.s_in_base(1)}]") + else: + self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") self._emit_empty_line() @@ -193,7 +200,11 @@ class 
macro_move_slice_window_block_wise_t(macro_base_t): def __init__(self, mc, tunable, inline, **options): macro_base_t.__init__(self, mc, True) self.tunable = tunable - self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset + if tunable.tensor_a_pass_through: + self.declare_arg("s_in_base") # 64bit acc + self.declare_arg("s_in_c_itr") + else: + self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset self.declare_arg("v_wei_os") self.declare_arg("s_move_slice_k_stride_c") # this is indeed gemm_k * data_byte, same for input/weight self.declare_arg("s_gemm_k_num_c") # c * data_byte @@ -201,12 +212,20 @@ def __init__(self, mc, tunable, inline, **options): self.options = options def name(self): - return '.v_fwd_gtc_nhwc_move_slice_window_block_wise' + return f'.v_fwd_gtc_nhwc_move_slice_window_block_wise_{self.tunable.tensor_a_pass_through}_{self.tunable.tensor_b_pass_through}' def expr(self): - self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") + if self.tunable.tensor_a_pass_through: + self._emit(f"s_add_u32 s[{self.s_in_base()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_base()}]") + self._emit(f"s_addc_u32 s[{self.s_in_base(1)}], 0, s[{self.s_in_base(1)}]") + else: + self._emit(f"s_add_u32 s[{self.s_in_offset()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_offset()}]") self._emit(f"v_add_u32 v[{self.v_wei_os()}], s[{self.s_move_slice_k_stride_c()}], v[{self.v_wei_os()}]") - self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_offset()}]") + if self.tunable.tensor_a_pass_through: + self._emit(f"s_add_u32 s[{self.s_in_c_itr()}], s[{self.s_move_slice_k_stride_c()}], s[{self.s_in_c_itr()}]") + self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_c_itr()}]") + else: + self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_offset()}]") self._emit(f"s_cselect_b32 s[{self.s_flag_need_acc_yx()}], 1, 0") self._emit_empty_line() @@ -218,7 +237,12 @@ class macro_move_slice_window_block_wise_acc_yx_t(macro_base_t): def __init__(self, mc, tunable, inline, **options): macro_base_t.__init__(self, mc, True) self.tunable = tunable - self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset + if tunable.tensor_a_pass_through: + self.declare_arg("s_in_base") + self.declare_arg("s_in_c_itr") # + self.declare_arg("s_gemm_k_num_c") # used to U64 sub s_in_base, can be None + else: + self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset self.declare_arg("v_in_os") self.declare_arg("v_in_ihi_list") self.declare_arg("v_in_iwi_list") @@ -256,7 +280,12 @@ def expr(self): self._emit(f"s_cmp_eq_u32 1, s[{self.s_flag_need_acc_yx()}]") self._emit(f"s_cbranch_scc0 {label_acc_yx_end} ; no need do accumulate yx") self._emit_front(f"{label_acc_yx}:") - self._emit(f"s_mov_b32 s[{self.s_in_offset()}], 0") # reset input offset. wei, no care + if self.tunable.tensor_a_pass_through: + self._emit(f"s_sub_u32 s[{self.s_in_base()}], s[{self.s_in_base()}], s[{self.s_gemm_k_num_c()}]") + self._emit(f"s_subb_u32 s[{self.s_in_base(1)}], s[{self.s_in_base(1)}], 0") + self._emit(f"s_mov_b32 s[{self.s_in_c_itr()}], 0") # reset input offset. wei, no care + else: + self._emit(f"s_mov_b32 s[{self.s_in_offset()}], 0") # reset input offset. 
wei, no care ''' ix accumulate, will only accumulate in width, and will never carry on to height iy accumulate, will only accumulate in height, and will never carry on to batch @@ -303,25 +332,28 @@ def __init__(self, mc, outer): mc_base_t.__init__(self, mc) self.outer = outer def get_issues(self): - m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() + m_wei_2d_global_load, m_in_2d_global_load = self.outer.get_macro_global_load() return m_in_2d_global_load.get_issues() def __call__(self): s = self.outer.sgpr v = self.outer.vgpr + tunable = self.outer.tunable m_wei_2d_global_load, m_in_2d_global_load = self.outer.get_macro_global_load() with self._deferred_context(): self._emit(f"; load input, nxe:{self.outer.tunable.nxe}") #if self.outer.tunable.nxe != 0: - if True: - self._emit(f".v_clear_nc {v.v_gld_a()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") + # if tunable.tensor_a_pass_through: + self._emit(f".v_clear_nc {v.v_gld_a()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") + if tunable.tensor_a_pass_through: + self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), None, s.s_in_stride_k_pack(), s.s_in_offset(), + *(v.v_in_flag(), v.v_tmp()) if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG else (v.v_in_flag(),))) + else: if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), s.s_in_offset(), v.v_in_os(), v.v_in_flag(), v.v_tmp())) else: self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), s.s_in_offset(), v.v_in_os(), v.v_in_flag())) - # else: - # self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), s.s_in_offset(), v.v_in_os(), None)) return self._get_deferred() @@ -342,7 +374,7 @@ def __call__(self): with self._deferred_context(): self._emit(f"; load weight") if self.outer.tunable.precache_soffset: - self._emit(f".v_clear_nc {v.v_gld_b()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") + # self._emit(f".v_clear_nc {v.v_gld_b()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") self._emit(m_wei_2d_global_load(v.v_gld_b(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), v.v_wei_flag())) else: self._emit(m_wei_2d_global_load(v.v_gld_b(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_tmp())) @@ -426,6 +458,7 @@ class kernel_sgpr_t(mc_base_t): def __init__(self, mc, outer): mc_base_t.__init__(self, mc) ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = outer.get_thread_lengths() + k_pack = outer.get_k_pack() sseq = gpr_sequencer_t() self.outer = outer self.s_ka = sym_t('s_ka' , sseq(2)) @@ -484,10 +517,17 @@ def __init__(self, mc, outer): self.s_move_slice_k_ix = sym_t("s_move_slice_k_ix" , sseq(1)) self.s_flag_need_acc_yx = sym_t("s_flag_need_acc_yx" , sseq(1)) self.s_kitr = sym_t("s_kitr" , 1) - self.s_in_offset = sym_t("s_in_offset" , sseq(1)) + if outer.tunable.tensor_a_pass_through: + # need s precache + in_npc = ((ta_c // k_pack) - 2) if ((ta_c // k_pack) - 2 > 0 ) else 0 + + self.s_in_offset = sym_t("s_in_offset" , sseq(in_npc)) + self.s_in_c_itr = sym_t("s_in_c_itr" , 2) + self.s_in_stride_k_pack = sym_t("s_in_stride_k_pack" , sseq(1)) + else: + self.s_in_offset = sym_t("s_in_offset" , sseq(1)) if outer.tunable.precache_soffset: m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() - #in_npc = m_in_2d_global_load.get_num_precache_soffset() wei_npc = m_wei_2d_global_load.get_num_precache_soffset() self.s_wei_offset = 
sym_t("s_wei_offset" ,sseq(wei_npc)) @@ -523,13 +563,18 @@ def __init__(self, mc, outer): nk_per_thread = tb_k0 if tb_k0 != 1 else tb_k1 assert nb_per_thread <= 16, "we pack flag into single vgpr" + k_pack = outer.get_k_pack() + share_load_packed = k_pack if outer.tunable.tensor_a_pass_through or outer.tunable.tensor_b_pass_through else 1 + is_vgpr_acc_c = outer.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS vseq = gpr_sequencer_t() + num_vgpr_acc_a = share_load_packed * outer.tunable.num_vgpr_accumulate_a if not outer.tunable.tensor_a_pass_through else 0 + num_vgpr_acc_b = share_load_packed * outer.tunable.num_vgpr_accumulate_b if not outer.tunable.tensor_b_pass_through else 0 if is_vgpr_acc_c: self.v_c = sym_t("v_c" ,vseq(outer.tunable.num_vgpr_accumulate_c)) v_c_num = vseq() else: - v_c_resuable_num = outer.tunable.num_vgpr_accumulate_a + outer.tunable.num_vgpr_accumulate_b + \ + v_c_resuable_num = num_vgpr_acc_a + num_vgpr_acc_b + \ outer.tunable.num_vgpr_global_load_a + outer.tunable.num_vgpr_global_load_b + \ 16 # from v_sst_a_os to v_co_sst v_c_coalescing_num = outer.tunable.num_agpr_accumulate_c // outer.coalescing_store_groups @@ -538,14 +583,18 @@ def __init__(self, mc, outer): v_c_needed = v_c_needed if v_c_needed > 0 else 0 # let at least 0 self.v_c = sym_t("v_c" ,vseq(v_c_needed), f"coalescing:{v_c_coalescing_num}, needed:{v_c_needed}, resuable:{v_c_resuable_num}") - self.v_a = sym_t("v_a" ,vseq(outer.tunable.num_vgpr_accumulate_a)) - self.v_b = sym_t("v_b" ,vseq(outer.tunable.num_vgpr_accumulate_b)) + if not outer.tunable.tensor_a_pass_through: + self.v_a = sym_t("v_a" ,vseq(num_vgpr_acc_a)) + if not outer.tunable.tensor_b_pass_through: + self.v_b = sym_t("v_b" ,vseq(num_vgpr_acc_b)) self.v_gld_a = sym_t("v_gld_a" ,vseq(outer.tunable.num_vgpr_global_load_a)) self.v_gld_b = sym_t("v_gld_b" ,vseq(outer.tunable.num_vgpr_global_load_b)) - self.v_sst_a_os = sym_t("v_sst_a_os" ,vseq(1)) - self.v_sst_b_os = sym_t("v_sst_b_os" ,vseq(1)) - self.v_sld_a_os = sym_t("v_sld_a_os" ,vseq(1)) - self.v_sld_b_os = sym_t("v_sld_b_os" ,vseq(1)) + if not outer.tunable.tensor_a_pass_through: + self.v_sst_a_os = sym_t("v_sst_a_os" ,vseq(1)) + self.v_sld_a_os = sym_t("v_sld_a_os" ,vseq(1)) + if not outer.tunable.tensor_b_pass_through: + self.v_sst_b_os = sym_t("v_sst_b_os" ,vseq(1)) + self.v_sld_b_os = sym_t("v_sld_b_os" ,vseq(1)) self.v_in_os = sym_t("v_in_os" ,vseq(nb_per_thread)) self.v_in_ihi_list = sym_t("v_in_ihi_list" ,vseq(nb_per_thread)) @@ -559,7 +608,12 @@ def __init__(self, mc, outer): self.v_wei_os = sym_t("v_wei_os" ,vseq(1)) self.v_out_os = sym_t("v_out_os" ,vseq(1)) - self.v_gtc_ic = sym_t("v_gtc_ic" ,vseq(1)) + if outer.tunable.tensor_a_pass_through: + self.v_gtc_ic_a = sym_t("v_gtc_ic_a" ,self.v_gld_a.value) + if outer.tunable.tensor_b_pass_through: + self.v_gtc_ic_b = sym_t("v_gtc_ic_b" ,self.v_gld_b.value) + if not (outer.tunable.tensor_a_pass_through and outer.tunable.tensor_b_pass_through): + self.v_gtc_ic = sym_t("v_gtc_ic" ,vseq(1)) self.v_in_inb = sym_t("v_in_inb" ,vseq(1)) self.v_in_in = sym_t("v_in_in" ,vseq(1)) self.v_wei_ik = sym_t("v_wei_ik" ,vseq(1)) @@ -576,6 +630,7 @@ def __init__(self, mc, outer): self.v_co_sub_n_index = sym_t("v_co_sub_n_index" ,self.v_gemm_in.value) self.v_tmp = sym_t("v_tmp" ,vseq(6, 2)) + self.v_wei_tmp_pack = sym_t("v_wei_tmp_pack" ,self.v_gld_a.value - 1 if self.v_gld_a.value > 1 else vseq(1)) if nk_per_thread <= 4 and IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG == 0: self.v_wei_flag = sym_t("v_wei_flag" ,self.v_tmp.value) else: @@ -669,7 +724,7 @@ 
def get_dims_lengths(self): def get_thread_copy_dims(self): ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() in_thread_copy_dims = [ta_nb0, ta_nb1, ta_e, ta_c] - wei_thread_copy_dims = [tb_k0, tb_k1, ta_e, ta_c] # always reordered! + wei_thread_copy_dims = [tb_k0, tb_k1, tb_e, tb_c] # always reordered! return in_thread_copy_dims, wei_thread_copy_dims def get_thread_copy_index(self): @@ -683,6 +738,20 @@ def get_thread_copy_index(self): ''' return in_thread_copy_index, wei_thread_copy_index + def get_k_pack(self): + ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + if (not self.tunable.tensor_a_pass_through and not self.tunable.tensor_b_pass_through) or \ + (self.tunable.tensor_a_pass_through and self.tunable.tensor_b_pass_through): + assert ta_c == tb_c + return tb_c + else: + if self.tunable.tensor_a_pass_through: + assert ta_c % tb_c == 0 + return tb_c + else: + assert tb_c % ta_c == 0 + return ta_c + def get_macro_global_load(self): ''' NOTICE: input/wei always load gemm_k (e*c) first. indeed always load c, and do vector load if possible @@ -696,28 +765,38 @@ def get_macro_global_load(self): ctrl_wei_gld = ctrl_2d_global_load_t() ctrl_in_gld = ctrl_2d_global_load_t() - ctrl_wei_gld.vector_d1 = utility_gcd(ta_c, 4) if ta_c != 1 else 1 + ctrl_wei_gld.vector_d1 = utility_gcd(tb_c, 4) if tb_c != 1 else 1 ctrl_in_gld.vector_d1 = utility_gcd(ta_c, 4) if ta_c != 1 else 1 - if self.wei_thread_copy_ndim == 2: - ctrl_wei_gld.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[0]] - ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[1]] - elif self.wei_thread_copy_ndim == 1: - ctrl_wei_gld.length_d0 = 1 - ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] + if self.tunable.tensor_b_pass_through: + ctrl_wei_gld.length_d0 = tb_k0 if tb_k0 != 1 else tb_k1 + ctrl_wei_gld.length_d1 = tb_c + ctrl_wei_gld.vector_d1 = self.get_k_pack() else: - ctrl_wei_gld.length_d0 = 1 - ctrl_wei_gld.length_d1 = wei_thread_copy_dims[-1] + if self.wei_thread_copy_ndim == 2: + ctrl_wei_gld.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[0]] + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[1]] + elif self.wei_thread_copy_ndim == 1: + ctrl_wei_gld.length_d0 = 1 + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[wei_thread_copy_index[0]] + else: + ctrl_wei_gld.length_d0 = 1 + ctrl_wei_gld.length_d1 = wei_thread_copy_dims[-1] - if self.in_thread_copy_ndim == 2: - ctrl_in_gld.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] - ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[1]] - elif self.in_thread_copy_ndim == 1: - ctrl_in_gld.length_d0 = 1 - ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] + if self.tunable.tensor_a_pass_through: + ctrl_in_gld.length_d0 = ta_c // self.get_k_pack() + ctrl_in_gld.length_d1 = (ta_nb0 if ta_nb0 != 1 else ta_nb1) * self.get_k_pack() + ctrl_in_gld.vector_d1 = self.get_k_pack() else: - ctrl_in_gld.length_d0 = 1 - ctrl_in_gld.length_d1 = in_thread_copy_dims[-1] + if self.in_thread_copy_ndim == 2: + ctrl_in_gld.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] + ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[1]] + elif self.in_thread_copy_ndim == 1: + ctrl_in_gld.length_d0 = 1 + ctrl_in_gld.length_d1 = in_thread_copy_dims[in_thread_copy_index[0]] + else: + ctrl_in_gld.length_d0 = 1 + ctrl_in_gld.length_d1 = in_thread_copy_dims[-1] ctrl_in_gld.use_flag = 1 ctrl_wei_gld.use_flag = 1 @@ -728,7 +807,9 @@ def 
get_macro_global_load(self): ctrl_in_gld.bfe_flag = 1 if self.tunable.precache_soffset: - return macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_wei_gld, inline), \ + return macro_igemm_2d_global_load_precache_sv_offset_t(self.mc, ctrl_wei_gld, inline) if self.tunable.tensor_b_pass_through else \ + macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_wei_gld, inline), \ + macro_igemm_2d_global_load_precache_sv_offset_t(self.mc, ctrl_in_gld, inline) if self.tunable.tensor_a_pass_through else \ macro_igemm_2d_global_load_precache_voffset_t(self.mc, ctrl_in_gld, inline) else: return macro_igemm_2d_global_load_t(self.mc, ctrl_wei_gld, inline), macro_igemm_2d_global_load_precache_voffset_t(self.mc, ctrl_in_gld, inline) @@ -740,26 +821,29 @@ def get_macro_shared_store(self): ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() data_byte = amdgpu_precision_data_byte(self.tunable.precision) - k_pack = ta_c # always use this as k_pack + k_pack = self.get_k_pack() - # input is gemm_k * gemm_m * k_pack - in_sst_ctrl = ctrl_3d_shared_store_t() - in_sst_ctrl.length_d0 = ta_nb0 - in_sst_ctrl.length_d1 = ta_nb1 - in_sst_ctrl.length_dp = k_pack - in_sst_ctrl.stride_d0 = na_nb1 * k_pack * data_byte - in_sst_ctrl.stride_d1 = k_pack * data_byte + if not self.tunable.tensor_a_pass_through: + # input is gemm_k * gemm_m * k_pack + in_sst_ctrl = ctrl_3d_shared_store_t() + in_sst_ctrl.length_d0 = ta_nb0 + in_sst_ctrl.length_d1 = ta_nb1 + in_sst_ctrl.length_dp = k_pack + in_sst_ctrl.stride_d0 = na_nb1 * k_pack * data_byte + in_sst_ctrl.stride_d1 = k_pack * data_byte - # wei is gemm_k * gemm_n * k_pack - wei_sst_ctrl = ctrl_3d_shared_store_t() - wei_sst_ctrl.length_d0 = tb_k0 - wei_sst_ctrl.length_d1 = tb_k1 - wei_sst_ctrl.length_dp = k_pack - wei_sst_ctrl.stride_d0 = nb_k1 * k_pack * data_byte - wei_sst_ctrl.stride_d1 = k_pack * data_byte + if not self.tunable.tensor_b_pass_through: + # wei is gemm_k * gemm_n * k_pack + wei_sst_ctrl = ctrl_3d_shared_store_t() + wei_sst_ctrl.length_d0 = tb_k0 + wei_sst_ctrl.length_d1 = tb_k1 + wei_sst_ctrl.length_dp = k_pack + wei_sst_ctrl.stride_d0 = nb_k1 * k_pack * data_byte + wei_sst_ctrl.stride_d1 = k_pack * data_byte inline = True if self.tunable.fma_interleave else False - return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline), macro_igemm_3d_shared_store_t(self.mc, wei_sst_ctrl, inline) + return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline) if not self.tunable.tensor_a_pass_through else None, \ + macro_igemm_3d_shared_store_t(self.mc, wei_sst_ctrl, inline) if not self.tunable.tensor_b_pass_through else None def get_macro_move_slice_window(self): inline = True if self.tunable.fma_interleave else False @@ -930,12 +1014,12 @@ def get_kernel_macros(self): if type(rtn) is tuple: for e in rtn: #if hasattr(e, 'emit'): - if not e.is_inline(): + if e is not None and not e.is_inline(): #continue kernel_macros.extend([m for m in rtn]) else: #if hasattr(rtn, 'emit'): - if not e.is_inline(): + if rtn is not None and rtn.is_inline(): #continue kernel_macros.append(rtn) return kernel_macros @@ -972,7 +1056,7 @@ def emit_kernel_prologue(self): s_dummy = sym_t("s_dummy") - k_pack = ta_c # always use this as k_pack + k_pack = self.get_k_pack() # start emit self._emit(f"s_load_dwordx2 s[{s.s_p_in((0,1))}], s[{s.s_ka((0, 1))}], 0+{k.k_p_in()}") @@ -991,15 +1075,28 @@ def emit_kernel_prologue(self): self._emit(f"s_load_dwordx2 s[{s.s_magic_2((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_2()}") self._emit(f"s_load_dword 
s[{s.s_shift_pack_0()}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") - self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}") + self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, k_pack:{k_pack}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") - self._emit(tc_index_dispatcher(v.v_gtc_ic(), v.v_tmp(), ca_c, ta_c)) - self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), ca_nb1, ta_nb1, True)) + if self.tunable.tensor_a_pass_through: + self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), ca_nb1, ta_nb1)) + self._emit(tc_index_dispatcher(v.v_gtc_ic_a(), v.v_tmp(), ca_c, k_pack)) # <= note here, thread length is further reduced! + self._emit(tc_index_dispatcher(v.v_tmp(1), v.v_tmp(), ca_nb0, ta_nb0, True)) + self._emit(tc_index_accumulator(v.v_in_inb(), v.v_tmp(1), v.v_in_inb(), ca_nb0, ca_nb1, na_nb0, na_nb1)) + else: + self._emit(tc_index_dispatcher(v.v_gtc_ic(), v.v_tmp(), ca_c, ta_c)) + self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), ca_nb1, ta_nb1, True)) - self._emit(f"; wei(e, c, k0, k1) thread_length: {ta_e}x{ta_c}x{tb_k0}x{tb_k1}, cluster_length: {ca_e}x{ca_c}x{cb_k0}x{cb_k1}") + self._emit(f"; wei(e, c, k0, k1) thread_length: {tb_e}x{tb_c}x{tb_k0}x{tb_k1}, cluster_length: {cb_e}x{cb_c}x{cb_k0}x{cb_k1}, k_pack:{k_pack}") # weight ic same as input - self._emit(f"v_lshrrev_b32 v[{v.v_tmp()}], {igemm_log2(ca_c)}, v0") - self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k1, tb_k1, True)) + if (not self.tunable.tensor_a_pass_through) and (not self.tunable.tensor_b_pass_through): + self._emit(f"v_lshrrev_b32 v[{v.v_tmp()}], {igemm_log2(ca_c)}, v0") + self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k1, tb_k1, True)) + elif self.tunable.tensor_a_pass_through: + self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") + self._emit(tc_index_dispatcher(v.v_gtc_ic(), v.v_tmp(), cb_c, tb_c)) + self._emit(tc_index_dispatcher(v.v_wei_ik(), v.v_tmp(), cb_k1, tb_k1, True)) + else: + assert False, "unimplemented" self._emit_empty_line() @@ -1131,7 +1228,13 @@ def emit_kernel_prologue(self): def calculate_and_load_input(): self._emit(f"; calculate in offset") - self._emit(f"s_mov_b32 s[{s.s_in_offset()}], 0") + if self.tunable.tensor_a_pass_through: + self._emit(f"s_mov_b32 s[{s.s_in_c_itr()}], 0") + self._emit(f"s_mov_b32 s[{s.s_in_stride_k_pack()}], {ca_c * k_pack * data_byte}") + for i in range(2, ta_c // k_pack): + self._emit(f"s_mul_i32 s[{s.s_in_offset(i - 2)}], s[{s.s_in_stride_k_pack()}], {i}") + else: + self._emit(f"s_mov_b32 s[{s.s_in_offset()}], 0") # compute group distance self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") self._emit(f"s_mul_hi_u32 s[{s.s_tmp(1)}], s[{s.s_block_gtc_ig()}], s[{s.s_c()}]") @@ -1143,7 +1246,7 @@ def calculate_and_load_input(): # s_in_stride_wi need shift before! 
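# For reference, an element-level sketch (illustrative only, strides assumed in elements)
# of the NHWC input byte offset that the v_add_lshl_u32 / v_mul_lo_u32 sequence just
# below builds up per pixel:
def nhwc_in_byte_offset(n, ihi, iwi, ic, hi, wi, c, data_byte=4):
    stride_n  = hi * wi * c      # s_in_stride_n: elements between two batches
    stride_wi = c                # s_in_stride_wi: elements between adjacent w pixels
    return (n * stride_n + (ihi * wi + iwi) * stride_wi + ic) * data_byte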
self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) - self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic_a() if self.tunable.tensor_a_pass_through else v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(0)}]") self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(0)}], v[{v.v_tmp()}]") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") @@ -1160,7 +1263,10 @@ def calculate_and_load_input(): self._emit_empty_line() # voffset, for [1, nb_per_thread) pixels - thread_stride = na_nb1 if ta_nb0 != 1 else 1 + if self.tunable.tensor_a_pass_through: + thread_stride = ca_nb0 * ca_nb1 + else: + thread_stride = na_nb1 if ta_nb0 != 1 else 1 for i in range(1, nb_per_thread): self._emit(f"s_mov_b32 s1, {thread_stride * i}") @@ -1195,7 +1301,7 @@ def calculate_and_load_input(): self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") - self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic_a() if self.tunable.tensor_a_pass_through else v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(i)}]") self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(i)}], v[{v.v_tmp()}]") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_in_stride_wi()}], v[{v.v_tmp()}]") @@ -1217,7 +1323,14 @@ def calculate_and_load_input(): # load in self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") - self._emit(self.global_load_in()) + if self.tunable.tensor_a_pass_through: + mbb_gld_in = create_machine_basic_block(self.global_load_in()) + gld_per_k = self.tunable.wave_repeat_m * self.tunable.wave_step_m + for i_mbb in mbb_gld_in[0:(-1 * gld_per_k)]: + # TODO: need multiple load of pass through side + self._emit(machine_basic_block_call(self, i_mbb)) + else: + self._emit(self.global_load_in()) self._emit_empty_line() def calculate_and_load_weight(): @@ -1235,7 +1348,7 @@ def calculate_and_load_weight(): # wei flag self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") self._emit(f"v_cndmask_b32 v[{v.v_wei_flag()}], 0, 1, vcc") - self._emit(f"v_mov_b32 v[{v.v_b()}], v[{v.v_wei_flag()}]") + self._emit(f"v_mov_b32 v[{v.v_wei_tmp_pack()}], v[{v.v_wei_flag()}]") for i in range(1, nk_per_thread): if i == 1: @@ -1244,7 +1357,7 @@ def calculate_and_load_weight(): self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_tmp()}], v[{v.v_tmp(5)}]") self._emit(f"v_cmp_gt_u32 vcc, s[{s.s_k()}], v[{v.v_tmp(5)}]") self._emit(f"v_cndmask_b32 v[{v.v_wei_flag(i)}], 0, 1, vcc") - self._emit(f"v_lshl_or_b32 v[{v.v_b()}], v[{v.v_wei_flag(i)}], {i}, v[{v.v_b()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_wei_tmp_pack()}], v[{v.v_wei_flag(i)}], {i}, v[{v.v_wei_tmp_pack()}]") self._emit_empty_line() if self.wei_thread_copy_ndim != 1: @@ -1257,54 +1370,73 @@ def calculate_and_load_weight(): if self.tunable.precache_soffset: self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) + self._emit(f".v_clear_nc {v.v_gld_b()}, {m_wei_2d_global_load.ctrl.length_d0 * 
m_wei_2d_global_load.ctrl.length_d1}") self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_wei(3)}], 0x27000") - self._emit(self.global_load_wei()) + if self.tunable.tensor_b_pass_through: + mbb_gld_wei = create_machine_basic_block(self.global_load_wei()) + gld_per_k = self.tunable.wave_repeat_n * self.tunable.wave_step_n + for i_mbb in mbb_gld_wei[0:(-1 * gld_per_k)]: + # TODO: need multiple load of pass through side + self._emit(machine_basic_block_call(self, i_mbb)) + else: + self._emit(self.global_load_wei()) self._emit_empty_line() # do load - calculate_and_load_input() calculate_and_load_weight() + calculate_and_load_input() + if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") self._emit(self.thread_mapping(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp())) else: + v_pack = k_pack if self.tunable.tensor_a_pass_through or self.tunable.tensor_b_pass_through else 1 self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") - self._emit(self.xdlops_mapping.get_gemm_index_for_src_matrix(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp(), k_pack)) + self._emit(self.xdlops_mapping.get_gemm_index_for_src_matrix(v.v_gemm_in(), v.v_gemm_im(), v.v_tmp(5), v.v_tmp(), + k_pack=k_pack, v_pack=v_pack)) self._emit(f"v_mov_b32 v[{v.v_tmp(5)}], v0") self._emit(self.xdlops_mapping.get_gemm_index_for_dst_matrix(v.v_co_sst(), v.v_co_sld(), v.v_tmp(5), v.v_tmp())) ''' gemm_k * gemm_m * k_pack ''' - self._emit(f"; LDS store, in: e,c,nb0,nb1: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, k_pack:{k_pack}") - if k_pack != 1: - self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_in_inb()}]") - self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_tmp(2)}]") - else: - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_in_inb()}]") - self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") - self._emit_empty_line() + if not self.tunable.tensor_a_pass_through: + self._emit(f"; LDS store, in: e,c,nb0,nb1: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, k_pack:{k_pack}") + if k_pack != 1: + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_in_inb()}]") + self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_tmp(2)}]") + else: + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(na_nb0*na_nb1 * k_pack)}, v[{v.v_in_inb()}]") + self._emit(f"v_lshlrev_b32 v[{v.v_sst_a_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") + self._emit_empty_line() + self._emit(f"v_lshlrev_b32 v[{v.v_sld_a_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_im()}] ; LDS load in") - self._emit(f"; LDS store, wei: e,c,k: {ta_e}x{ta_c}x{tb_k0}x{tb_k1}, {ca_e}x{ca_c}x{cb_k0}x{cb_k1}, k_pack:{k_pack}") - if k_pack != 1: - self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_wei_ik()}]") - self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(nb_k0*nb_k1 * k_pack)}, v[{v.v_tmp(2)}]") - else: - self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(nb_k0*nb_k1 * 
k_pack)}, v[{v.v_wei_ik()}]") - self._emit(f"v_lshlrev_b32 v[{v.v_sst_b_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") - self._emit(f"v_add_u32 v[{v.v_sst_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_b_os()}]") - self._emit_empty_line() + if not self.tunable.tensor_b_pass_through: + self._emit(f"; LDS store, wei: e,c,k: {tb_e}x{tb_c}x{tb_k0}x{tb_k1}, {cb_e}x{cb_c}x{cb_k0}x{cb_k1}, k_pack:{k_pack}") + if k_pack != 1: + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {igemm_log2(k_pack)}, v[{v.v_wei_ik()}]") + self._emit(f"v_lshrrev_b32 v[{v.v_tmp(1)}], {igemm_log2(k_pack)}, v[{v.v_gtc_ic()}]") + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_tmp(1)}], {igemm_log2(nb_k0*nb_k1 * k_pack)}, v[{v.v_tmp(2)}]") + else: + self._emit(f"v_lshl_or_b32 v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(nb_k0*nb_k1 * k_pack)}, v[{v.v_wei_ik()}]") + self._emit(f"v_lshlrev_b32 v[{v.v_sst_b_os()}], {igemm_log2(data_byte)}, v[{v.v_tmp()}]") + if not self.tunable.tensor_a_pass_through: + self._emit(f"v_add_u32 v[{v.v_sst_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sst_b_os()}]") + self._emit_empty_line() + self._emit(f"v_lshlrev_b32 v[{v.v_sld_b_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_in()}] ; LDS load wei") + if not self.tunable.tensor_a_pass_through: + self._emit(f"v_add_u32 v[{v.v_sld_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sld_b_os()}]") - self._emit(f"; LDS load") - self._emit(f"v_lshlrev_b32 v[{v.v_sld_b_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_in()}]") - self._emit(f"v_lshlrev_b32 v[{v.v_sld_a_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_im()}]") - self._emit(f"v_add_u32 v[{v.v_sld_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sld_b_os()}]") - self._emit_empty_line() + # self._emit(f"; LDS load") + #if not self.tunable.tensor_a_pass_through: + #self._emit(f"v_lshlrev_b32 v[{v.v_sld_b_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_in()}]") + #self._emit(f"v_lshlrev_b32 v[{v.v_sld_a_os()}], {igemm_log2(data_byte)}, v[{v.v_gemm_im()}]") + #self._emit(f"v_add_u32 v[{v.v_sld_b_os()}], {self.tunable.lds_a_np2}, v[{v.v_sld_b_os()}]") + #self._emit_empty_line() if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self._emit(f"v_mov_b32 v[{v.v_gemm_in()}], v[{v.v_co_sst()}]") @@ -1346,7 +1478,7 @@ def calculate_and_load_weight(): self._emit(f"s_lshl_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_c()}], {igemm_log2(data_byte)}") w_flag_cnt = 0 - self._emit(f"v_bfe_u32 v[{v.v_wei_flag(0)}], v[{v.v_b()}], {0}, 1") + self._emit(f"v_bfe_u32 v[{v.v_wei_flag(0)}], v[{v.v_wei_tmp_pack()}], {0}, 1") w_flag_cnt = w_flag_cnt + 1 # if self.tunable.nxe != 0: @@ -1355,7 +1487,7 @@ def calculate_and_load_weight(): # else: self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_c()}], {na_c * data_byte}") if w_flag_cnt < nk_per_thread: - self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_b()}], {w_flag_cnt}, 1") + self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_wei_tmp_pack()}], {w_flag_cnt}, 1") w_flag_cnt = w_flag_cnt + 1 if self.tunable.nxe != 0: @@ -1376,16 +1508,19 @@ def calculate_and_load_weight(): self._emit(f"s_mov_b32 s[{s.s_p_out(2)}], 0xffffffff") if w_flag_cnt < nk_per_thread: - self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_b()}], {w_flag_cnt}, 1") + self._emit(f"v_bfe_u32 v[{v.v_wei_flag(w_flag_cnt)}], v[{v.v_wei_tmp_pack()}], {w_flag_cnt}, 1") w_flag_cnt = w_flag_cnt + 1 self._emit(f"s_mov_b32 s[{s.s_p_out(3)}], 0x27000") for i_w in range(w_flag_cnt, nk_per_thread): - self._emit(f"v_bfe_u32 v[{v.v_wei_flag(i_w)}], v[{v.v_b()}], {i_w}, 1") + self._emit(f"v_bfe_u32 
v[{v.v_wei_flag(i_w)}], v[{v.v_wei_tmp_pack()}], {i_w}, 1") def emit_kernel_fma_main_loop(self): s = self.sgpr v = self.vgpr + k = self.karg + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + k_pack = self.get_k_pack() m_move_slice_window = self.get_macro_move_slice_window() m_move_slice_window_accumulate = self.get_macro_move_slice_window_accumulate() @@ -1397,7 +1532,7 @@ def move_slice_window_b(): if self.tunable.nxe != 0: with self._deferred_context(): self._emit(m_move_slice_window( - s.s_in_offset(), + *(s.s_p_in(), s.s_in_c_itr()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),), v.v_wei_os(), s.s_move_slice_k_stride_c(), s.s_gemm_k_num_c(), @@ -1406,7 +1541,7 @@ def move_slice_window_b(): else: with self._deferred_context(): self._emit(m_move_slice_window( - s.s_in_offset(), + s.s_p_in() if self.tunable.tensor_a_pass_through else s.s_in_offset(), v.v_wei_os(), s.s_move_slice_k_stride_c())) return self._get_deferred() @@ -1421,7 +1556,7 @@ def move_slice_window_acc(): with self._deferred_context(): if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: self._emit(m_move_slice_window_accumulate( - s.s_in_offset(), + *(s.s_p_in(), s.s_in_c_itr(), s.s_gemm_k_num_c()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),), v.v_in_os(), v.v_in_ihi_list(), v.v_in_iwi_list(), @@ -1440,7 +1575,7 @@ def move_slice_window_acc(): s.s_tmp())) else: self._emit(m_move_slice_window_accumulate( - s.s_in_offset(), + *(s.s_p_in(), s.s_in_c_itr(), s.s_gemm_k_num_c()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),), v.v_in_os(), v.v_in_ihi_list(), v.v_in_iwi_list(), @@ -1516,41 +1651,51 @@ def move_slice_window_acc(): fctrl.interleave = self.tunable.fma_interleave # functor - fctrl.global_load_a_functor = self.global_load_wei - fctrl.global_load_b_functor = self.global_load_in - fctrl.shared_store_a_functor = self.shared_store_wei - fctrl.shared_store_b_functor = self.shared_store_in + # fctrl.global_load_a_functor = self.global_load_wei + # fctrl.global_load_b_functor = self.global_load_in + # fctrl.shared_store_a_functor = self.shared_store_wei + # fctrl.shared_store_b_functor = self.shared_store_in + fctrl.global_load_a_functor = self.global_load_in + fctrl.global_load_b_functor = self.global_load_wei + fctrl.shared_store_a_functor = self.shared_store_in + fctrl.shared_store_b_functor = self.shared_store_wei - ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() - fctrl.lds_k_pack = ta_c + # ta_nb0, ta_nb1, ta_e, ta_c, tb_e, tb_c, tb_k0, tb_k1 = self.get_thread_lengths() + fctrl.lds_k_pack = k_pack + + share_load_packed = k_pack if self.tunable.tensor_a_pass_through or self.tunable.tensor_b_pass_through else 1 if ctrl_xdlops_mapping.wave_step_m == 1: - fctrl.shared_load_a_functor = inst_ds_read_t(data_byte) # xdlops load from LDS always single load + fctrl.shared_load_a_functor = inst_ds_read_t(data_byte * share_load_packed) # xdlops load from LDS always single load else: assert ctrl_xdlops_mapping.wave_step_m == 2, "currently only support wave_step_m is 2" - fctrl.shared_load_a_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ta_c*ctrl_xdlops_mapping.wave_tile_m * data_byte, sym_t(self.vgpr.v_tmp(4))) + fctrl.shared_load_a_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte * share_load_packed, k_pack*ctrl_xdlops_mapping.wave_tile_m * data_byte, sym_t(self.vgpr.v_tmp(4))) if ctrl_xdlops_mapping.wave_step_n == 1: - fctrl.shared_load_b_functor = inst_ds_read_t(data_byte) # xdlops load from LDS 
always single load + fctrl.shared_load_b_functor = inst_ds_read_t(data_byte * share_load_packed) # xdlops load from LDS always single load else: assert ctrl_xdlops_mapping.wave_step_n == 2, "currently only support wave_step_n is 2" - fctrl.shared_load_b_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte, ta_c*ctrl_xdlops_mapping.wave_tile_n * data_byte, sym_t(self.vgpr.v_tmp(5))) + fctrl.shared_load_b_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte * share_load_packed, k_pack*ctrl_xdlops_mapping.wave_tile_n * data_byte, sym_t(self.vgpr.v_tmp(5))) fctrl.move_slice_window_a_functor = move_slice_window_a fctrl.move_slice_window_b_functor = move_slice_window_b fctrl.move_slice_window_accumule_functor = move_slice_window_acc if self.tunable.nxe != 0 else None # sympol type - fctrl.v_a = v.v_a - fctrl.v_b = v.v_b + fctrl.v_a = v.v_a if not self.tunable.tensor_a_pass_through else None + fctrl.v_b = v.v_b if not self.tunable.tensor_b_pass_through else None fctrl.a_c = a.a_c fctrl.v_gld_a = v.v_gld_a fctrl.v_gld_b = v.v_gld_b - fctrl.v_sld_a_os = v.v_sld_a_os - fctrl.v_sld_b_os = v.v_sld_b_os - fctrl.v_sst_a_os = v.v_sst_a_os - fctrl.v_sst_b_os = v.v_sst_b_os + fctrl.v_sld_a_os = v.v_sld_a_os if not self.tunable.tensor_a_pass_through else None + fctrl.v_sld_b_os = v.v_sld_b_os if not self.tunable.tensor_b_pass_through else None + fctrl.v_sst_a_os = v.v_sst_a_os if not self.tunable.tensor_a_pass_through else None + fctrl.v_sst_b_os = v.v_sst_b_os if not self.tunable.tensor_b_pass_through else None fctrl.s_kitr = s.s_kitr fctrl.s_knum = s.s_knum + fctrl.pass_through_a = self.tunable.tensor_a_pass_through + fctrl.pass_through_b = self.tunable.tensor_b_pass_through + fctrl.pass_through_v_pack_a = self.get_k_pack() + fctrl.pass_through_v_pack_b = self.get_k_pack() mfma_main_loop = mfma_main_loop_t(self.mc, fctrl) mfma_main_loop.emit() diff --git a/igemm/algo/mfma_main_loop.py b/igemm/algo/mfma_main_loop.py index 65f0a781..0db0248d 100644 --- a/igemm/algo/mfma_main_loop.py +++ b/igemm/algo/mfma_main_loop.py @@ -30,6 +30,9 @@ from .mfma import * from .xdlops_mapping import * from .nop import * +import re + +MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT = 1 # last wait for ds_read advance a mfma slot class ctrl_mfma_main_loop_t(object): def __init__(self): @@ -73,6 +76,11 @@ def __init__(self): self.lds_pad_m = 0 # pad how many pixels per m row self.lds_pad_n = 0 # pad how many pixels per n row + self.pass_through_a = 0 # a tensor not using LDS + self.pass_through_b = 0 # b tensor not using LDS + self.pass_through_v_pack_a = 1 # passthough tensor may have v pack, indicate vector load + self.pass_through_v_pack_b = 1 + class mfma_main_loop_t(mc_base_t): ''' ''' @@ -80,7 +88,409 @@ def __init__(self, mc, ctrl): mc_base_t.__init__(self, mc) self.ctrl = ctrl assert type(ctrl) is ctrl_mfma_main_loop_t + + def emit_single_pass_through(self): + ''' + one side of A/B tensor not using LDS, used for skinny gemm + a/b -> p/q, where p side passthrough lds, q side is normal + ''' + + p_idx = 0 if self.ctrl.pass_through_a else 1 + q_idx = p_idx ^ 1 + ctrl = self.ctrl + + label_mfma_body = 'L_{}_mfma_body'.format(self.ctrl.label_prefix) + label_mfma_finishing = 'L_{}_mfma_finishing'.format(self.ctrl.label_prefix) + label_mfma_end = 'L_{}_mfma_end'.format(self.ctrl.label_prefix) + + f_gld_p = [ctrl.global_load_a_functor, ctrl.global_load_b_functor][p_idx] + f_gld_q = [ctrl.global_load_a_functor, ctrl.global_load_b_functor][q_idx] + f_sst_p = [ctrl.shared_store_a_functor, 
ctrl.shared_store_b_functor][p_idx] + f_sst_q = [ctrl.shared_store_a_functor, ctrl.shared_store_b_functor][q_idx] + + f_sld_p = [ctrl.shared_load_a_functor, ctrl.shared_load_b_functor][p_idx] + f_sld_q = [ctrl.shared_load_a_functor, ctrl.shared_load_b_functor][q_idx] + + f_move_slice_window_p = [ctrl.move_slice_window_a_functor, ctrl.move_slice_window_b_functor][p_idx] + f_move_slice_window_q = [ctrl.move_slice_window_a_functor, ctrl.move_slice_window_b_functor][q_idx] + f_move_slice_window_acc = ctrl.move_slice_window_accumule_functor + + v_gld_p = [ctrl.v_gld_a, ctrl.v_gld_b][p_idx] + v_gld_q = [ctrl.v_gld_a, ctrl.v_gld_b][q_idx] + + a_c = ctrl.a_c + v_q = [ctrl.v_a, ctrl.v_b][q_idx] + v_sld_q_os = [ctrl.v_sld_a_os, ctrl.v_sld_b_os][q_idx] + v_sst_q_os = [ctrl.v_sst_a_os, ctrl.v_sst_b_os][q_idx] + + s_kitr = ctrl.s_kitr + s_knum = ctrl.s_knum + cxm = ctrl.cxm + + data_byte = amdgpu_precision_data_byte(ctrl.data_type) + + lds_width_m = data_byte * cxm.wave_tile_m * cxm.wave_step_m * cxm.waves_per_m() * cxm.wave_repeat_m + lds_width_n = data_byte * cxm.wave_tile_n * cxm.wave_step_n * cxm.waves_per_n() * cxm.wave_repeat_n + lds_single_size = ctrl.lds_single_size + + lds_width_q = [lds_width_m, lds_width_n][q_idx] + + # used as offset:x number. may some + lds_base_m = 0 + lds_base_n = 0 + assert ctrl.unroll_k % cxm.block_k() == 0 + unroll_k = ctrl.unroll_k + k_per_inst = cxm.block_k() + + pad_m = ctrl.lds_pad_m + pad_n = ctrl.lds_pad_n + + lds_base_q = [lds_base_m, lds_base_n][q_idx] + pad_q = [pad_m, pad_n][q_idx] + + num_v_p = [cxm.inst_mfma.num_v_a, cxm.inst_mfma.num_v_b][p_idx] + num_v_q = [cxm.inst_mfma.num_v_a, cxm.inst_mfma.num_v_b][q_idx] + wave_step_p = [cxm.wave_step_m, cxm.wave_step_n][p_idx] + wave_step_q = [cxm.wave_step_m, cxm.wave_step_n][q_idx] + wave_repeat_p = [cxm.wave_repeat_m, cxm.wave_repeat_n][p_idx] + wave_repeat_q = [cxm.wave_repeat_m, cxm.wave_repeat_n][q_idx] + + v_pack_p = [ctrl.pass_through_v_pack_a, ctrl.pass_through_v_pack_b][p_idx] + v_pack_q = [ctrl.pass_through_v_pack_a, ctrl.pass_through_v_pack_b][q_idx] + assert v_pack_p == v_pack_q, "currently only support p, q the same" + + assert unroll_k % (v_pack_p * k_per_inst) == 0 + unroll_k_slot = unroll_k // (v_pack_p * k_per_inst) + + def global_load_p(): + with self._deferred_context(): + self._emit(f_gld_p()) + return self._get_deferred() + + def global_load_q(): + with self._deferred_context(): + self._emit(f_gld_q()) + return self._get_deferred() + + def move_slice_window_pq(): + with self._deferred_context(): + if f_move_slice_window_p: + self._emit(f_move_slice_window_p()) + if f_move_slice_window_q: + self._emit(f_move_slice_window_q()) + return self._get_deferred() + + def move_slice_window_acc(): + with self._deferred_context(): + self._emit(f_move_slice_window_acc()) + return self._get_deferred() + + def call_mbb(mbb): + return machine_basic_block_call(self, mbb) + + # parse global load of p tensor into list of single load + mbb_gld_p = create_machine_basic_block(global_load_p()) + mbb_gld_q = create_machine_basic_block(global_load_q()) + + mbb_p_clear = 1 if mbb_gld_p[0].mc_inst(-1).type() == MC_INST_TYPE_LEGACY_MACRO else 0 + mbb_q_clear = 1 if mbb_gld_q[0].mc_inst(-1).type() == MC_INST_TYPE_LEGACY_MACRO else 0 + + if mbb_p_clear == 1: + # hack on v_clear_nc + v_clear_nc_strs = mbb_gld_p[0].mc_inst(-1).inst_str + v_clear_nc_list = re.split('[,\s]+', v_clear_nc_strs) + assert len(v_clear_nc_list) == 3 and v_clear_nc_list[0] == '.v_clear_nc' + num_gld_p = int(v_clear_nc_list[2]) # TODO: check number + 
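# Sketch of the re-wrap performed here (illustrative, helper name hypothetical): one
# leading ".v_clear_nc v_gld_p, N" in front of M buffer_loads is split into M blocks,
# each clearing its own N // M registers immediately before the load it protects, so the
# clears can later be interleaved into mfma slots together with their loads.
def split_clear_nc(num_total_vgpr, num_loads):
    per_issue = num_total_vgpr // num_loads        # mirrors num_gld_p_per_issue just below
    return [(i * per_issue, per_issue) for i in range(num_loads)]   # (start vgpr, count)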
assert num_gld_p % (len(mbb_gld_p) - mbb_p_clear) == 0 + num_gld_p_per_issue = num_gld_p // (len(mbb_gld_p) - mbb_p_clear) + def emit_v_clear_nc_p(i): + with self._deferred_context(): + self._emit(f".v_clear_nc {v_gld_p(i * num_gld_p_per_issue)}, {num_gld_p_per_issue}") + return self._get_deferred() + + mbb_gld_p_wrapper = list() + for i in range(len(mbb_gld_p) - mbb_p_clear): + mbb_gld_p_wrapper += create_machine_basic_block(emit_v_clear_nc_p(i) + '\n' + call_mbb(mbb_gld_p[i+1]), merge_mbb = 1) + + mbb_gld_p = mbb_gld_p_wrapper + mbb_p_clear = 0 + + num_p_issue = len(mbb_gld_p) - mbb_p_clear + num_q_issue = len(mbb_gld_q) - mbb_q_clear + + mbb_msw_pq = create_machine_basic_block(move_slice_window_pq(), merge_mbb = 1) if (f_move_slice_window_p or f_move_slice_window_q) else list() + mbb_msw_acc = create_machine_basic_block(move_slice_window_acc(), merge_mbb = 1) if f_move_slice_window_acc else list() + + def mapped_ioffset(i_k, width_byte, pad_pixel, offset = 0): + k_pack = self.ctrl.lds_k_pack + i_k0 = i_k // k_pack + i_kp = i_k % k_pack + return i_k0 * (width_byte * k_pack + pad_pixel * data_byte) + i_kp * data_byte + offset * k_pack + + def mi_q(i_k, offset = 0): + return mapped_ioffset(i_k, lds_width_q, pad_q, offset) + + def mfma_step_pxq_vk(i_k, i_repeat_p, i_repeat_q, i_v, i_local_buffer_q = 0): + # v_pack is in k direction, hence c_index stay the same across different i_v + mfma = cxm.inst_mfma + num_agpr_per_issue = mfma.num_a_c + with self._deferred_context(): + for i_step_q in range(wave_step_q): + for i_step_p in range(wave_step_p): + if p_idx == 0: + c_index = i_repeat_p * wave_step_p * wave_step_q * wave_repeat_q * num_agpr_per_issue + \ + i_repeat_q * wave_step_p * wave_step_q * num_agpr_per_issue + \ + i_step_p * wave_step_q * num_agpr_per_issue + \ + i_step_q * num_agpr_per_issue + else: + c_index = i_repeat_q * wave_step_q * wave_step_p * wave_repeat_p * num_agpr_per_issue + \ + i_repeat_p * wave_step_q * wave_step_p * num_agpr_per_issue + \ + i_step_q * wave_step_p * num_agpr_per_issue + \ + i_step_p * num_agpr_per_issue + c_index_end = c_index + num_agpr_per_issue - 1 + + p_index = i_k * wave_repeat_p * wave_step_p * v_pack_p * num_v_p + \ + i_repeat_p * wave_step_p * v_pack_p * num_v_p + \ + i_step_p * v_pack_p * num_v_p + \ + i_v * num_v_p + + q_index = i_local_buffer_q * wave_step_q * wave_repeat_q * v_pack_q * num_v_q + \ + i_repeat_q * wave_step_q * v_pack_q * num_v_q + \ + i_step_q * v_pack_q * num_v_q + \ + i_v * num_v_q + self._emit(mfma(a_c((c_index, c_index_end)), v_gld_p(p_index), v_q(q_index), a_c((c_index, c_index_end))) + f" ; repeat:{i_repeat_p}x{i_repeat_q}, step:{i_step_p}x{i_step_q}, k:{i_k}, v:{i_v}, num_a_c:{num_agpr_per_issue}") + return self._get_deferred() + + def mfma_loop(): + mfma = cxm.inst_mfma + + repeat_q_thread_offset = wave_step_q * num_v_q * v_pack_p + local_buffer_q = wave_repeat_q * repeat_q_thread_offset + mfma_v_pack_slot = unroll_k_slot * wave_repeat_p * wave_repeat_q # TODO: not consider step + cnt_mfma_v_pack_slot = 0 + + def first_sld(): + # when start of mfma main loop, do this load + with self._deferred_context(): + for i in range(wave_repeat_q): + self._emit(f_sld_q(v_q(i * repeat_q_thread_offset), v_sld_q_os(), lds_base_q + mi_q(0, i * (lds_width_q // 2)))) + if ctrl.local_prefetch_num == 2: + # always load a single piece of repeat + self._emit(f_sld_q(v_q(wave_step_q * wave_repeat_q * v_pack_q * num_v_q), v_sld_q_os(), lds_base_q + mi_q(1 * v_pack_p * k_per_inst, 0))) + return self._get_deferred() + + mbb_first_sld = 
create_machine_basic_block(first_sld()) + + def mfma_per_k_slot(i_k, i_mfma_v_pack_slot, is_last_fma): + ''' + k slot is unroll_k / k_per_inst + pattern: + prefetch:1, repeat:1 (phase:1) + 0 0 0 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 0 0 1 0 + 1 0 0 0 2 0 + 2 0 0 0 3 0 + 3 0 0 0 4 0 x + + prefetch:2, repeat:1 (phase:1) + 0 0 0 + 0 1 1 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 0 0 2 1 + 1 0 0 1 3 1 + 2 0 0 0 4 1 x + 3 0 0 1 5 0 x + + prefetch:1, repeat:2 (phase:2) + 0 0 0 + 1 0 0 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 0 0 1 1 + 0 1 1 0 1 1 + 1 0 0 0 2 1 + 1 1 1 0 2 1 + 2 0 0 0 3 1 + 2 1 0 0 3 1 + 3 0 1 0 4 1 x + 3 1 1 0 4 0 x + + prefetch:2, repeat:2 (phase:3) + 0 0 0 + 1 0 0 + 0 1 1 + i_k i_r load_i_r load_i_buf load_i_k lgkmcnt need_load + 0 0 1 1 1 2 + 0 1 0 0 2 2 + 1 0 1 0 2 2 + 1 1 0 1 3 2 + 2 0 1 1 3 2 + 2 1 0 0 4 2 x + 3 0 1 0 4 1 x + 3 1 0 1 5 0 x + ''' + pref = ctrl.local_prefetch_num + rept = wave_repeat_q + phase = pref + rept - 1 # idx before entering main loop + + i_r_sequence = [ x & (rept - 1) for x in range(pref * rept)] + i_b_sequence = [(x >> (rept - 1)) & (pref - 1) for x in range(pref * rept)] + + i_local_buffer_q = i_k & 1 if pref == 2 else 0 + i_k_sst_q = i_k == (unroll_k_slot - ctrl.local_prefetch_num) + # print(f"i_k:{i_k}, i_k_sst_q:{i_k_sst_q}") + gld_p_per_k = wave_repeat_p * wave_step_p + cnt_mfma = 0 + def try_do_gld_per_slot(i_slot): + if is_last_fma: + if i_k == 0: + mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] + else: + mbb_gld_p_per_k = list() + mbb_gld_per_k = mbb_gld_p_per_k + else: + if i_k == 0: + mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] + else: + start_p_idx = mbb_p_clear if i_k == 1 else ((i_k - 1) * gld_p_per_k + mbb_p_clear) # always no clear + mbb_gld_p_per_k = mbb_gld_p[start_p_idx : i_k * gld_p_per_k + mbb_p_clear ] + mbb_gld_per_k = mbb_gld_p_per_k + mbb_msw_pq + mbb_msw_acc + mbb_gld_q if i_k == 0 else mbb_gld_p_per_k + num_gld_slot_per_k = wave_repeat_p * wave_repeat_q * v_pack_p + num_gld_per_slot = utility_next_mul(len(mbb_gld_per_k), num_gld_slot_per_k) // num_gld_slot_per_k + for i_gld in range(num_gld_per_slot): + current_gld = i_slot * num_gld_per_slot + i_gld + if current_gld < len(mbb_gld_per_k): + self._emit(call_mbb(mbb_gld_per_k[current_gld])) + + def do_sst_q(): + # print(f"do_sst_q, i_k:{i_k}") + if ctrl.lds_buffer_num == 1: + self._emit(f"s_barrier") + self._emit(f_sst_q()) + if ctrl.lds_buffer_num != 1: + self._emit(f"v_xor_b32 v[{v_sst_q_os()}], {hex(lds_single_size)}, v[{v_sst_q_os()}]") + + def do_sld_q(i_v, i_r): + # interleave into different v_pack + i_idx = i_k * rept + i_r + i_idx_mod = (i_idx + phase) % (pref * rept) + i_idx_int = (i_idx + phase) // (pref * rept) + + # print(f" ==i_r_sequence:{i_r_sequence}, i_b_sequence:{i_b_sequence}, i_idx:{i_idx}, mod:{i_idx_mod}, int:{i_idx_int}") + + load_i_r = i_r_sequence[i_idx_mod] + load_i_b = i_b_sequence[i_idx_mod] + load_i_k = i_idx_int * pref + load_i_b + + if i_v == (v_pack_p - 1) and load_i_k < unroll_k_slot: + the_str = f' ; i_r:{load_i_r}, i_b:{load_i_b}, i_k:{load_i_k}' + self._emit(f_sld_q(v_q(load_i_b * local_buffer_q + load_i_r * repeat_q_thread_offset), v_sld_q_os(), lds_base_q + mi_q(load_i_k * v_pack_p * k_per_inst, load_i_r * (lds_width_q // 2) )) + the_str) + + + if i_k == 0: + for mbb_1st in mbb_first_sld[1:]: + self._emit(call_mbb(mbb_1st)) + + for i_rp in range(wave_repeat_p): + # cnt_p_load = cnt_p_load + 1 + for i_rq in range(wave_repeat_q): + if i_rq != 
0: + vmcnt_str = "" + else: + if i_k == 0: + vmcnt_str = f'vmcnt({num_p_issue - 1 - gld_p_per_k})' + else: + if not is_last_fma: + vmcnt_str = f'vmcnt({num_p_issue + num_q_issue - 2})' + else: + vmcnt_str = f'vmcnt({num_p_issue - i_k - 1})' + num_lgkmcnt = (pref + rept - 2) - ((pref - 1 + i_rq) if i_k == (unroll_k_slot-1) else 0) + if MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT and num_lgkmcnt == 0: + # we need a change to put last lgkmcnt earlier + assert vmcnt_str == "" + if is_last_fma: + self._emit(f's_waitcnt lgkmcnt(0)') + else: + self._emit(f's_waitcnt lgkmcnt({num_lgkmcnt}) {vmcnt_str}') + + for i_v in range(v_pack_p): + self._emit(mfma_step_pxq_vk(i_k, i_rp, i_rq, i_v, i_local_buffer_q)) + if MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT: + if (i_mfma_v_pack_slot == mfma_v_pack_slot - 2) and (v_pack_p == 1 or i_v == (v_pack_p // 2) - 1): + assert i_rq == 0 + if not is_last_fma: + self._emit(f's_waitcnt lgkmcnt(0) vmcnt({num_p_issue - gld_p_per_k})') + do_sst_q() + do_sld_q(i_v, i_rq) # will not emit when last ds wait, hence will never co-exist when last ds wait emit + #if not is_last_fma: + try_do_gld_per_slot(cnt_mfma) + cnt_mfma = cnt_mfma + 1 + assert i_mfma_v_pack_slot < mfma_v_pack_slot, f'i_mfma_v_pack_slot:{i_mfma_v_pack_slot}, mfma_v_pack_slot:{mfma_v_pack_slot}' + i_mfma_v_pack_slot = i_mfma_v_pack_slot + 1 + + if not is_last_fma and i_k == (unroll_k_slot - 1): + self._emit(f's_waitcnt lgkmcnt(0)') + self._emit(f"s_barrier") + self._emit(call_mbb(mbb_first_sld[0])) + self._emit(f"s_sub_i32 s[{s_kitr()}], s[{s_kitr()}], {unroll_k}") + self._emit(f"s_cmp_gt_i32 s[{s_kitr()}], 0") + self._emit(f"s_cbranch_scc1 {label_mfma_body}") + return i_mfma_v_pack_slot + + self._emit(call_mbb(mbb_first_sld[0])) + self._emit(f"s_sub_i32 s[{s_kitr()}], s[{s_knum()}], {unroll_k}") + self._emit(f"s_cmp_gt_i32 s[{s_kitr()}], 0") + self._emit(f"s_cbranch_scc0 {label_mfma_end}") + self._emit_empty_line() + + self._emit_front(f"{label_mfma_body}:") + self._emit(f"; do fma accumulate with unroll {unroll_k}, mfma_v_pack_slot:{mfma_v_pack_slot}") + + for i_k in range(unroll_k_slot): + cnt_mfma_v_pack_slot = mfma_per_k_slot(i_k, cnt_mfma_v_pack_slot, False) + + self._emit_front(f"{label_mfma_end}:") + cnt_mfma_v_pack_slot = 0 + for i_k in range(unroll_k_slot): + cnt_mfma_v_pack_slot = mfma_per_k_slot(i_k, cnt_mfma_v_pack_slot, True) + + # start emit, first load q tensor, then p tensor. 
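# A small self-contained helper (illustrative only) that reproduces the prefetch/repeat
# pattern tabulated in the docstring above, i.e. which repeat, which local buffer and
# which k slot each ds_read issued by do_sld_q targets:
def sld_q_schedule(pref, rept, unroll_k_slot):
    phase = pref + rept - 1
    i_r_seq = [x & (rept - 1) for x in range(pref * rept)]
    i_b_seq = [(x >> (rept - 1)) & (pref - 1) for x in range(pref * rept)]
    rows = []
    for i_k in range(unroll_k_slot):
        for i_r in range(rept):
            i_idx = i_k * rept + i_r
            i_mod = (i_idx + phase) % (pref * rept)
            i_int = (i_idx + phase) // (pref * rept)
            load_i_r, load_i_b = i_r_seq[i_mod], i_b_seq[i_mod]
            rows.append((i_k, i_r, load_i_r, load_i_b, i_int * pref + load_i_b))
    return rows   # (i_k, i_r, load_i_r, load_i_buf, load_i_k)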
+ self._emit(f"; start MFMA loop, wave tile:{cxm.wave_tile_m}x{cxm.wave_tile_n}, repeat:{cxm.wave_repeat_m}x{cxm.wave_repeat_n}, step:{cxm.wave_step_m}x{cxm.wave_step_n}" +\ + f", k_pack:{self.ctrl.lds_k_pack}, p_issue:{num_p_issue}, q_issue:{num_q_issue}, local_prefetch_num:{ctrl.local_prefetch_num}") + + self._emit(f".v_clear_acc_c {a_c()}, {cxm.total_acc_c()}") + # self._emit(f"; make sure acc WAR harzard, at least 1 nop for src_c") + + self._emit(f"s_waitcnt vmcnt({f_gld_p.get_issues() - wave_repeat_p * wave_step_p})") + self._emit(f_sst_q()) + self._emit_empty_line() + + # decrese k + # self._emit(f"s_sub_i32 s[{s_kitr()}], s[{s_knum()}], {unroll_k}") + # self._emit(f"s_cmp_gt_i32 s[{s_kitr()}], 0") + # self._emit(f"s_cbranch_scc0 {label_mfma_end}") + # self._emit_empty_line() + + # right after clear acc + # self._emit(f_move_slice_window_p()) + # self._emit(f_move_slice_window_q()) + # if f_move_slice_window_acc != None: + # self._emit(f_move_slice_window_acc()) + + self._emit(f"s_waitcnt lgkmcnt(0)") + self._emit(f"s_barrier") + self._emit_empty_line() + + mfma_loop() + + nop = emit_nop_t(self.mc) + nop(cxm.inst_mfma.get_nop_count_mfma_acc_raw()) # solve dependency + + def emit(self): + if self.ctrl.pass_through_a ^ self.ctrl.pass_through_b: + return self.emit_single_pass_through() + label_mfma_body = 'L_{}_mfma_body'.format(self.ctrl.label_prefix) label_mfma_finishing = 'L_{}_mfma_finishing'.format(self.ctrl.label_prefix) label_mfma_end = 'L_{}_mfma_end'.format(self.ctrl.label_prefix) @@ -241,7 +651,6 @@ def do_unroll_k_1x1_sub(): self._emit_front(f"{label_mfma_finishing}:") self._emit(mfma_step_mxn(0, 0, 1, 1)) - self._emit_front(f"{label_mfma_end}:") self._emit("s_waitcnt lgkmcnt(0)") self._emit("s_barrier") diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py index b9f44711..cb21c84e 100755 --- a/igemm/algo/xdlops_mapping.py +++ b/igemm/algo/xdlops_mapping.py @@ -288,6 +288,7 @@ def serialize(self): ctrl_xdlops_mapping_t( 64 , 128, 64, 32, 1, 4, 1, 1, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 64 , 128, 32, 32, 2, 4, 1, 2, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 2, 2, 2, 1, 1, v_mfma_f32_32x32x2f32), + ctrl_xdlops_mapping_t( 128, 64 , 64, 32, 1, 2, 1, 2, 1, 1, v_mfma_f32_32x32x1f32), ctrl_xdlops_mapping_t( 64 , 128, 32, 32, 2, 2, 2, 2, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 1, 2, 2, 2, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), @@ -308,6 +309,7 @@ def serialize(self): ctrl_xdlops_mapping_t( 16 , 128, 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), # need re-design coalescing. 
or do irregular gemm ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 1, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 2, 1, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 64 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), @@ -376,17 +378,24 @@ def __init__(self, mc, ctrl): mc_base_t.__init__(self, mc) assert type(ctrl) is ctrl_xdlops_mapping_t self.ctrl = ctrl - def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp4, k_pack = 1): + def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp4, **options): ''' notice! this is to calculate LDS offset for A/B matrix input, it is not the same as C matrix output layout, due to xdlops C matrix output describe is in coalescint_store ''' + def get_dict_with_default(some_dict, key, default_value): + if key in some_dict: + return some_dict[key] + return default_value ctrl = self.ctrl #print(f"ctrl.block_n()={ctrl.block_n()}, ctrl.block_m()={ctrl.block_m()}") #print(f"ctrl.block_n_per_wave()={ctrl.block_n_per_wave()}, ctrl.block_m_per_wave()={ctrl.block_m_per_wave()}") assert ctrl.block_n() == ctrl.block_m() and ctrl.block_k() * ctrl.block_n() * ctrl.block_n_per_wave() * ctrl.block_m_per_wave() == AMDGPU_WAVE_SIZE + k_pack = get_dict_with_default(options, "k_pack", 1) + v_pack = get_dict_with_default(options, "v_pack", 1) + assert v_pack in (1, k_pack), 'currently only support v_pack is 1 or k_pack' with self._deferred_context(): - self._emit(f"; xdlops mapping, get source matrix gemm index, k_pack:{k_pack}") + self._emit(f"; xdlops mapping, get source matrix gemm index, k_pack:{k_pack}, v_pack:{v_pack}") self._emit(f"v_and_b32 v[{v_gemm_in}], {ctrl.block_n() - 1}, v[{v_thread_id}] ; block_n index ") self._emit(f"v_and_b32 v[{v_gemm_im}], {ctrl.block_m() - 1}, v[{v_thread_id}] ; block_m index ") if k_pack != 1: @@ -397,12 +406,16 @@ def get_gemm_index_for_src_matrix(self, v_gemm_in, v_gemm_im, v_thread_id, v_tmp if ctrl.block_k() != 1: self._emit(f"v_and_b32 v[{v_tmp4} + 0], {ctrl.block_k() - 1}, v[{v_thread_id}] ; block_k_per_wave index") if k_pack != 1: - self._emit(f"v_and_b32 v[{v_tmp4} + 1], {k_pack - 1}, v[{v_tmp4} + 0] ; and k_pack:{k_pack}") - self._emit(f"v_lshrrev_b32 v[{v_tmp4} + 0], {utility_log2(k_pack)}, v[{v_tmp4} + 0] ; shift right k_pack:{k_pack}") - self._emit(f"v_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 1], v[{v_gemm_in}] ; or k_pack:{k_pack}") - self._emit(f"v_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], v[{v_gemm_im}] ; or k_pack:{k_pack}") - self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n * k_pack)}, v[{v_gemm_in}]") - self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m * k_pack)}, v[{v_gemm_im}]") + if v_pack == 1: + self._emit(f"v_and_b32 v[{v_tmp4} + 1], {k_pack - 1}, v[{v_tmp4} + 0] ; and k_pack:{k_pack}") + self._emit(f"v_lshrrev_b32 v[{v_tmp4} + 0], {utility_log2(k_pack)}, v[{v_tmp4} + 0] ; shift right k_pack:{k_pack}") + self._emit(f"v_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 1], v[{v_gemm_in}] ; or k_pack:{k_pack}") + self._emit(f"v_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 1], v[{v_gemm_im}] ; or k_pack:{k_pack}") + self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n * k_pack)}, 
v[{v_gemm_in}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m * k_pack)}, v[{v_gemm_im}]") + else: + self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n * k_pack)}, v[{v_gemm_in}]") + self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m * k_pack)}, v[{v_gemm_im}]") else: self._emit(f"v_lshl_or_b32 v[{v_gemm_in}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_n)}, v[{v_gemm_in}]") self._emit(f"v_lshl_or_b32 v[{v_gemm_im}], v[{v_tmp4} + 0], {utility_log2(ctrl.macro_tile_m)}, v[{v_gemm_im}]") diff --git a/igemm/codegen/mbb.py b/igemm/codegen/mbb.py index 929434b3..bedf58d0 100644 --- a/igemm/codegen/mbb.py +++ b/igemm/codegen/mbb.py @@ -74,6 +74,8 @@ def get_mc_inst_type(inst_str): return MC_INST_TYPE_SHARE_MEM if mc_inst_is_global_mem(inst_op): return MC_INST_TYPE_GLOBAL_MEM + if mc_inst_is_legacy_macro(inst_op): + return MC_INST_TYPE_LEGACY_MACRO return MC_INST_TYPE_OTHER class mc_inst_t(object): @@ -156,6 +158,7 @@ def create_machine_basic_block(multi_line_inst_str, **option): option: group_mbb_by_end_of_inst_op : str, group several mc_inst into mbb, each mbb is by end of this value + merge_mbb : int, do not split into multiple mbb ''' class parse_mbb_list_t(object): STATE_NORMAL = 0 @@ -236,6 +239,8 @@ def get_dict_with_default(dictionary, key, default_value): return dictionary[key] else: return default_value + + merge_mbb = get_dict_with_default(option, "merge_mbb", 0) group_mbb_by_end_of_inst_op = get_dict_with_default(option, "group_mbb_by_end_of_inst_op", "") def match_group_mbb_by_end_of_inst_op(inst_op): @@ -258,11 +263,17 @@ def match_group_mbb_by_end_of_inst_op(inst_op): if len(istrs) == 0: return None + for i, istr in enumerate(istrs): mc_inst = create_mc_inst(istr) if not mc_inst: continue + # merge every string into a single mbb + if merge_mbb: + mc_inst_buffer.append(mc_inst) + continue + # early pass rule if self.is_mbb_start_macro_c_clear(i, istrs): ''' diff --git a/igemm/codegen/scheduler.py b/igemm/codegen/scheduler.py index a06790b0..7e588b2f 100644 --- a/igemm/codegen/scheduler.py +++ b/igemm/codegen/scheduler.py @@ -96,7 +96,7 @@ def mbb_is_macro_c_clear(mbb): ''' if mbb.length() == 1: if mbb.mc_inst().type() == MC_INST_TYPE_LEGACY_MACRO: - if get_mc_inst_op(mbb.mc_inst()).startswith('.v_clear_nc'): + if get_mc_inst_op(mbb.mc_inst().inst_str).startswith('.v_clear_nc'): return True return False @@ -138,7 +138,7 @@ def mbb_is_macro_c_clear(mbb): #else: # break assert num_gmem != 0, f"no global mem in this instructino list, please check" - assert num_v_c_clear in (0, 1) + # assert num_v_c_clear in (0, 1) num_gmem += num_v_c_clear # second decide how many global mem to interleave per interval From ba431746a9cebedf5ebf19ff94bc32f5c8d94aec Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 2 Mar 2021 13:40:19 +0800 Subject: [PATCH 24/40] update code --- config/igemm_fwd_gtc_gfx908_nhwc.config | 620 +++++++++++++++++++++++- igemm/algo/global_memory.py | 32 +- igemm/algo/igemm_base.py | 12 +- igemm/algo/igemm_fwd_gtc_nhwc.py | 34 +- igemm/algo/mfma_main_loop.py | 106 ++-- igemm/algo/xdlops_mapping.py | 2 + 6 files changed, 739 insertions(+), 67 deletions(-) diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index e24c3ea6..2cb3b0a5 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -135,24 +135,632 @@ tensor_layout = 'nhwc' nxb = 0 nxe = 1 -# 
#--------------------------- 128x128 + + +#--------------------------- 64x128 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 1, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +# #--------------------------- 128x64 # [igemm_fwd_gtc] # gemm_m_per_block = 128 -# gemm_n_per_block = 128 +# gemm_n_per_block = 64 # gemm_k_per_block = 32 # wave_tile_m = 32 # wave_step_m = 1 +# wave_repeat_m = 1 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_pass_through = 1 +# tensor_a_thread_lengths = [1,16, 1, 1] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 2, 4, 32] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 + + +# #--------------------------- 128x64 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 64 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 1 # wave_repeat_m = 2 # wave_tile_n = 32 # wave_step_n = 1 # wave_repeat_n = 2 # wave_tile_k = 2 # tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 -# tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 -# tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 -# tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0XK1 +# tensor_a_cluster_lengths = [1, 4, 1, 32] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 32] # ExCxK0XK1 +# direction = 
"fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 8, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 2 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 8, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 16] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 256x64 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 32 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x32 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 32 +gemm_k_per_block = 32 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 
+wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 8, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + + +#--------------------------- 64x64 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 32] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 8 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 2, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 64x16 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 16 +gemm_k_per_block = 4 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 1, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 1, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 32x16 +[igemm_fwd_gtc] +gemm_m_per_block = 32 +gemm_n_per_block = 16 +gemm_k_per_block = 16 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 2, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 2, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # 
ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 32x16 +[igemm_fwd_gtc] +gemm_m_per_block = 32 +gemm_n_per_block = 16 +gemm_k_per_block = 32 +wave_tile_m = 16 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 16 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 4 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +# #--------------------------- 32x16 +# [igemm_fwd_gtc] +# gemm_m_per_block = 32 +# gemm_n_per_block = 16 +# gemm_k_per_block = 8 +# wave_tile_m = 16 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 16 +# wave_step_n = 1 +# wave_repeat_n = 1 +# wave_tile_k = 1 +# tensor_a_thread_lengths = [1, 1, 2, 1] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 8, 1, 16] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 1, 1, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 8, 1, 16] # ExCxK0xK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 + + +#--------------------------- 64x4 +[igemm_fwd_gtc] +gemm_m_per_block = 64 +gemm_n_per_block = 4 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 4 +wave_step_n = 1 +wave_repeat_n = 1 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 1,16, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1,16, 1, 4] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 1, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1,16, 1, 4] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +# #---------------------------------------------------------------------------------- +# +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 64 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# tensor_a_thread_lengths = [1, 4, 1, 4] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 2 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_thread_lengths = [1, 4, 1, 4] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 64 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# tensor_a_thread_lengths = [1, 4, 1, 4] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 +# direction = "fwd" +# precision = "fp32" 
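On the host side these tiles are laid over the implicit-GEMM view of the convolution: the driver changes later in this series compute gemm_m = n * ho * wo and gemm_n = k / group, and with the ceiling form used for the magic-division denominators the per-group tile count is ceil(gemm_m / m_per_block) * ceil(gemm_n / n_per_block). A short sketch with made-up sizes (helper names are not from the driver):

    def ceil_div(a, b):
        return (a + b - 1) // b

    def fwd_nhwc_tile_count(n, k, ho, wo, group, m_per_block, n_per_block):
        gemm_m = n * ho * wo          # output pixels form the gemm_m dimension
        gemm_n = k // group           # output channels per group form gemm_n
        return ceil_div(gemm_m, m_per_block) * ceil_div(gemm_n, n_per_block)

    # e.g. n=4, k=256, ho=wo=28, group=1 on a 128x64 tile
    print(fwd_nhwc_tile_count(4, 256, 28, 28, 1, 128, 64))   # 25 * 4 = 100 tiles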
+# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 1 +# +# +# +# #--------------------------- 256x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 256 +# gemm_n_per_block = 128 +# gemm_k_per_block = 8 +# wave_tile_m = 64 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# tensor_a_thread_lengths = [1, 4, 1, 2] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 2, 1, 128] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 2, 1, 128] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# #--------------------------- 128x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_thread_lengths = [1, 4, 1, 2] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +# direction = "fwd" +# precision = "fp32" +# tensor_layout = 'nhwc' +# nxb = 0 +# nxe = 0 +# +# #--------------------------- 128x128 +# [igemm_fwd_gtc] +# gemm_m_per_block = 128 +# gemm_n_per_block = 128 +# gemm_k_per_block = 16 +# wave_tile_m = 32 +# wave_step_m = 1 +# wave_repeat_m = 2 +# wave_tile_n = 32 +# wave_step_n = 1 +# wave_repeat_n = 2 +# wave_tile_k = 2 +# tensor_a_thread_lengths = [1, 4, 1, 2] # ExCxNB0xNB1 +# tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +# tensor_b_thread_lengths = [1, 4, 1, 2] # ExCxK0xK1 +# tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 # direction = "fwd" # precision = "fp32" # tensor_layout = 'nhwc' # nxb = 0 -# nxe = 1 \ No newline at end of file +# nxe = 1 +# \ No newline at end of file diff --git a/igemm/algo/global_memory.py b/igemm/algo/global_memory.py index 613f3a7b..49a9d883 100755 --- a/igemm/algo/global_memory.py +++ b/igemm/algo/global_memory.py @@ -114,7 +114,7 @@ def __init__(self): # 2: d0 use vgpr precache, d1 use vgpr precache # 3: d0 use sgpr precache, d1 use sgpr precache # 4: .... maybe consider not using precache? 
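The flag_merge_v control added below changes how the 2d global-load macro guards out-of-range rows: when the validity flag lives on the v_offset side and there is only one d1 slice (n_d1 == 1), a single v_cmpx / exec-restore pair can bracket the whole length_d0 loop instead of wrapping every buffer_load individually. A rough Python sketch of the two emission orders, with operand text abbreviated rather than copied from the generator:

    def emit_flagged_loads(emit, length_d0, use_flag, merge_flag):
        # merge_flag mirrors ctrl.flag_merge_v: hoist the exec toggle out of the loop
        if merge_flag and use_flag:
            emit("v_cmpx_le_u32 vcc, 1, v[v_flag]")
        for i_d0 in range(length_d0):
            if not merge_flag and use_flag:
                emit("v_cmpx_le_u32 vcc, 1, v[v_flag]")
            emit(f"buffer_load_dword v[v_dst+{i_d0}], v[v_os], s[s_ptr:s_ptr+3], ... ; operands abbreviated")
            if not merge_flag and use_flag:
                emit("s_mov_b64 exec, -1")
        if merge_flag and use_flag:
            emit("s_mov_b64 exec, -1")

    emit_flagged_loads(print, 4, True, True)    # one exec toggle pair around 4 loads
    emit_flagged_loads(print, 4, True, False)   # one toggle pair per load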
- + self.flag_merge_v = 0 # when flag on v_offset, flag and multiple load, or flag per load class macro_igemm_2d_global_load_t(macro_base_t): @@ -560,15 +560,27 @@ def expr(self): if ctrl.src_order == 0 and ctrl.dst_order == 0: i_dst = 0 - for i_d0 in range(ctrl.length_d0): - for i_d1 in range(n_d1): - if ctrl.use_flag and self.v_flag != None: - self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_d1)}]") - current_s_offset = 0 if i_d0 == 0 else (self.s_stride_d1() if i_d0 == 1 else self.s_offset(i_d0 - 2)) - self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d1)}", f"{self.s_ptr()}", current_s_offset, 0)) - if ctrl.use_flag and self.v_flag != None: - self._emit(f"s_mov_b64 exec, -1") - i_dst = i_dst + 1 + if ctrl.flag_merge_v and n_d1 == 1: + # v is along d1 dimension, hence only possible when n_d1 is 1 + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag()}]") + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(1): + current_s_offset = 0 if i_d0 == 0 else (self.s_stride_d1() if i_d0 == 1 else self.s_offset(i_d0 - 2)) + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d1)}", f"{self.s_ptr()}", current_s_offset, 0)) + i_dst = i_dst + 1 + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + else: + for i_d0 in range(ctrl.length_d0): + for i_d1 in range(n_d1): + if ctrl.use_flag and self.v_flag != None: + self._emit(f"v_cmpx_le_u32 vcc, 1, v[{self.v_flag(i_d1)}]") + current_s_offset = 0 if i_d0 == 0 else (self.s_stride_d1() if i_d0 == 1 else self.s_offset(i_d0 - 2)) + self._emit(buffer_load_dword(f"{self.v_dst()}+{i_dst*ctrl.vector_d1}", f"{self.v_os(i_d1)}", f"{self.s_ptr()}", current_s_offset, 0)) + if ctrl.use_flag and self.v_flag != None: + self._emit(f"s_mov_b64 exec, -1") + i_dst = i_dst + 1 else: assert False diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index b9f48444..364d96b2 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -189,7 +189,7 @@ def __init__(self, tunable_dict): default_source_access_order = IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M if (self.direction == 'fwd' and self.tensor_layout == 'nchw') \ else IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N - self.source_access_order = utility_dict_with_default_t(tunable_dict)('source_access_order', default_source_access_order) + self.source_access_order = utility_dict_with_default_t(tunable_dict)('source_access_order', IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M) self.gemm_m_unmerge_cluster = utility_dict_with_default_t(tunable_dict)('gemm_m_unmerge_cluster', 0) self.gemm_n_unmerge_cluster = utility_dict_with_default_t(tunable_dict)('gemm_n_unmerge_cluster', 0) @@ -255,6 +255,9 @@ def _unmerge_x1_from_e(unroll_k, nxe): self.unmerge_sub_k = 1 self.unmerge_sub_c = self.gemm_n_per_block + self.tensor_a_pass_through_interleave_gld = 0 if self.tensor_layout == 'nhwc' else 1 + self.tensor_b_pass_through_interleave_gld = 0 if self.tensor_layout == 'nhwc' else 1 + self.fma_interleave = IGEMM_GTC_FEAT_FMA_INTERLEAVE self.local_prefetch_num = 1 # vector global/lds implicit here @@ -268,7 +271,7 @@ def _unmerge_x1_from_e(unroll_k, nxe): elif self.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self.local_prefetch_num = 2 if IGEMM_GTC_FEAT_LOCAL_PREFETCH else 1 - if (self.tensor_a_pass_through and self.wave_repeat_n) or (self.tensor_b_pass_through and self.wave_repeat_m): + if (self.tensor_a_pass_through and self.wave_repeat_n == 2) or 
(self.tensor_b_pass_through and self.wave_repeat_m == 2): self.local_prefetch_num = 1 # register for a,b,c buffer xdlops_mapping = get_ctrl_xdlops_mapping_fp32(self.gemm_m_per_block, self.gemm_n_per_block, self.block_size // amdgpu_wave_size(tunable_dict['arch'])) @@ -277,6 +280,9 @@ def _unmerge_x1_from_e(unroll_k, nxe): self.num_vgpr_accumulate_a = self.wave_step_m * self.wave_repeat_m * xdlops_mapping.inst_mfma.num_v_a * self.local_prefetch_num self.num_vgpr_accumulate_b = self.wave_step_n * self.wave_repeat_n * xdlops_mapping.inst_mfma.num_v_b * self.local_prefetch_num + self.global_prefetch_a_num = 2 if self.tensor_a_pass_through and not self.tensor_a_pass_through_interleave_gld else 1 + self.global_prefetch_b_num = 2 if self.tensor_b_pass_through and not self.tensor_b_pass_through_interleave_gld else 1 + self.num_vgpr_global_load_a = igemm_flatten_list_product(self.tensor_a_thread_lengths) self.num_vgpr_global_load_b = igemm_flatten_list_product(self.tensor_b_thread_lengths) @@ -380,6 +386,8 @@ def to_dict(self): tunable_dict['precache_soffset'] = self.precache_soffset tunable_dict['local_prefetch_num'] = self.local_prefetch_num + tunable_dict['global_prefetch_a_num'] = self.global_prefetch_a_num + tunable_dict['global_prefetch_b_num'] = self.global_prefetch_b_num tunable_dict['fma_interleave'] = self.fma_interleave tunable_dict['gemm_m_unmerge_cluster'] = self.gemm_m_unmerge_cluster diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index c8410272..4a382b71 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -36,6 +36,7 @@ from .mfma_main_loop import * IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG = 0 +# IGEMM_FWD_GTC_NHWC_P_INTERLEAVE_GLD = False # p tensor interleave def _find_non_1_index_in_list(list_object): result_list = list() @@ -345,9 +346,9 @@ def __call__(self): self._emit(f"; load input, nxe:{self.outer.tunable.nxe}") #if self.outer.tunable.nxe != 0: # if tunable.tensor_a_pass_through: - self._emit(f".v_clear_nc {v.v_gld_a()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") + self._emit(f".v_clear_nc {v.v_gld_a() if tunable.global_prefetch_a_num == 1 else v.v_gld_a_gpf()}, {m_in_2d_global_load.ctrl.length_d0 * m_in_2d_global_load.ctrl.length_d1}") if tunable.tensor_a_pass_through: - self._emit(m_in_2d_global_load(v.v_gld_a(), s.s_p_in(), v.v_in_os(), None, s.s_in_stride_k_pack(), s.s_in_offset(), + self._emit(m_in_2d_global_load(v.v_gld_a() if tunable.global_prefetch_a_num == 1 else v.v_gld_a_gpf(), s.s_p_in(), v.v_in_os(), None, s.s_in_stride_k_pack(), s.s_in_offset(), *(v.v_in_flag(), v.v_tmp()) if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG else (v.v_in_flag(),))) else: if IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG: @@ -588,7 +589,11 @@ def __init__(self, mc, outer): if not outer.tunable.tensor_b_pass_through: self.v_b = sym_t("v_b" ,vseq(num_vgpr_acc_b)) self.v_gld_a = sym_t("v_gld_a" ,vseq(outer.tunable.num_vgpr_global_load_a)) + if outer.tunable.global_prefetch_a_num == 2: + self.v_gld_a_gpf = sym_t("v_gld_a_gpf" ,vseq(outer.tunable.num_vgpr_global_load_a)) self.v_gld_b = sym_t("v_gld_b" ,vseq(outer.tunable.num_vgpr_global_load_b)) + if outer.tunable.global_prefetch_b_num == 2: + self.v_gld_b_gpf = sym_t("v_gld_b_gpf" ,vseq(outer.tunable.num_vgpr_global_load_b)) if not outer.tunable.tensor_a_pass_through: self.v_sst_a_os = sym_t("v_sst_a_os" ,vseq(1)) self.v_sld_a_os = sym_t("v_sld_a_os" ,vseq(1)) @@ -772,6 +777,7 @@ def get_macro_global_load(self): ctrl_wei_gld.length_d0 = tb_k0 if tb_k0 != 1 else tb_k1 
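For a pass-through tensor whose global loads are not interleaved into the MFMA slots, the new global_prefetch_*_num == 2 path double-buffers the fetch: the next unroll's data lands in the *_gpf shadow registers while the current v_gld block is consumed, and a per-register v_mov_b32 promotes the shadow block at the start of the following iteration. A toy model of that double buffering, not the real instruction schedule:

    def prefetched_loop(tiles, num_reg):
        v_gld = [0] * num_reg          # registers the mfma reads from
        v_gld_gpf = [0] * num_reg      # shadow registers the next tile lands in
        v_gld[:] = tiles[0]            # prologue load goes straight to v_gld
        out = []
        for i in range(len(tiles)):
            if i + 1 < len(tiles):
                v_gld_gpf[:] = tiles[i + 1]    # prefetch the next tile
            out.append(sum(v_gld))             # consume the current tile ("mfma")
            if i + 1 < len(tiles):
                v_gld[:] = v_gld_gpf           # per-register v_mov_b32 copy
        return out

    print(prefetched_loop([[1, 2], [3, 4], [5, 6]], 2))   # [3, 7, 11]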
ctrl_wei_gld.length_d1 = tb_c ctrl_wei_gld.vector_d1 = self.get_k_pack() + ctrl_wei_gld.flag_merge_v = 0 if self.tunable.tensor_b_pass_through_interleave_gld else 1 else: if self.wei_thread_copy_ndim == 2: ctrl_wei_gld.length_d0 = wei_thread_copy_dims[wei_thread_copy_index[0]] @@ -787,6 +793,7 @@ def get_macro_global_load(self): ctrl_in_gld.length_d0 = ta_c // self.get_k_pack() ctrl_in_gld.length_d1 = (ta_nb0 if ta_nb0 != 1 else ta_nb1) * self.get_k_pack() ctrl_in_gld.vector_d1 = self.get_k_pack() + ctrl_in_gld.flag_merge_v = 0 if self.tunable.tensor_a_pass_through_interleave_gld else 1 else: if self.in_thread_copy_ndim == 2: ctrl_in_gld.length_d0 = in_thread_copy_dims[in_thread_copy_index[0]] @@ -1323,7 +1330,7 @@ def calculate_and_load_input(): # load in self._emit(f"s_mov_b32 s[{s.s_p_in(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_in(3)}], 0x27000") - if self.tunable.tensor_a_pass_through: + if self.tunable.tensor_a_pass_through and self.tunable.tensor_a_pass_through_interleave_gld: mbb_gld_in = create_machine_basic_block(self.global_load_in()) gld_per_k = self.tunable.wave_repeat_m * self.tunable.wave_step_m for i_mbb in mbb_gld_in[0:(-1 * gld_per_k)]: @@ -1373,7 +1380,7 @@ def calculate_and_load_weight(): self._emit(f".v_clear_nc {v.v_gld_b()}, {m_wei_2d_global_load.ctrl.length_d0 * m_wei_2d_global_load.ctrl.length_d1}") self._emit(f"s_mov_b32 s[{s.s_p_wei(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_wei(3)}], 0x27000") - if self.tunable.tensor_b_pass_through: + if self.tunable.tensor_b_pass_through and self.tunable.tensor_b_pass_through_interleave_gld: mbb_gld_wei = create_machine_basic_block(self.global_load_wei()) gld_per_k = self.tunable.wave_repeat_n * self.tunable.wave_step_n for i_mbb in mbb_gld_wei[0:(-1 * gld_per_k)]: @@ -1686,16 +1693,23 @@ def move_slice_window_acc(): fctrl.a_c = a.a_c fctrl.v_gld_a = v.v_gld_a fctrl.v_gld_b = v.v_gld_b - fctrl.v_sld_a_os = v.v_sld_a_os if not self.tunable.tensor_a_pass_through else None - fctrl.v_sld_b_os = v.v_sld_b_os if not self.tunable.tensor_b_pass_through else None - fctrl.v_sst_a_os = v.v_sst_a_os if not self.tunable.tensor_a_pass_through else None - fctrl.v_sst_b_os = v.v_sst_b_os if not self.tunable.tensor_b_pass_through else None + fctrl.v_gld_a_gpf = v.v_gld_a_gpf if self.tunable.global_prefetch_a_num == 2 else None + fctrl.v_gld_b_gpf = v.v_gld_b_gpf if self.tunable.global_prefetch_b_num == 2 else None + fctrl.v_gld_a_num = self.tunable.num_vgpr_global_load_a + fctrl.v_gld_b_num = self.tunable.num_vgpr_global_load_b + fctrl.v_sld_a_os = v.v_sld_a_os if not self.tunable.tensor_a_pass_through else None + fctrl.v_sld_b_os = v.v_sld_b_os if not self.tunable.tensor_b_pass_through else None + fctrl.v_sst_a_os = v.v_sst_a_os if not self.tunable.tensor_a_pass_through else None + fctrl.v_sst_b_os = v.v_sst_b_os if not self.tunable.tensor_b_pass_through else None fctrl.s_kitr = s.s_kitr fctrl.s_knum = s.s_knum fctrl.pass_through_a = self.tunable.tensor_a_pass_through fctrl.pass_through_b = self.tunable.tensor_b_pass_through - fctrl.pass_through_v_pack_a = self.get_k_pack() - fctrl.pass_through_v_pack_b = self.get_k_pack() + fctrl.pass_through_a_v_pack = self.get_k_pack() + fctrl.pass_through_b_v_pack = self.get_k_pack() + + fctrl.pass_through_a_interleave_gld = 1 if self.tunable.tensor_a_pass_through_interleave_gld else 0 + fctrl.pass_through_b_interleave_gld = 1 if self.tunable.tensor_b_pass_through_interleave_gld else 0 mfma_main_loop = mfma_main_loop_t(self.mc, fctrl) mfma_main_loop.emit() diff --git 
a/igemm/algo/mfma_main_loop.py b/igemm/algo/mfma_main_loop.py index 0db0248d..66b5e34e 100644 --- a/igemm/algo/mfma_main_loop.py +++ b/igemm/algo/mfma_main_loop.py @@ -63,7 +63,11 @@ def __init__(self): self.v_b = None self.a_c = None self.v_gld_a = None + self.v_gld_a_gpf = None # used for a pass through and not interleaved, as global prefetch register + self.v_gld_a_num = 1 self.v_gld_b = None + self.v_gld_b_gpf = None # used for b pass through and not interleaved, as global prefetch register + self.v_gld_b_num = 1 self.v_sld_a_os = None self.v_sld_b_os = None self.v_sst_a_os = None @@ -78,8 +82,10 @@ def __init__(self): self.pass_through_a = 0 # a tensor not using LDS self.pass_through_b = 0 # b tensor not using LDS - self.pass_through_v_pack_a = 1 # passthough tensor may have v pack, indicate vector load - self.pass_through_v_pack_b = 1 + self.pass_through_a_v_pack = 1 # passthough tensor may have v pack, indicate vector load + self.pass_through_b_v_pack = 1 + self.pass_through_a_interleave_gld = 1 + self.pass_through_b_interleave_gld = 1 class mfma_main_loop_t(mc_base_t): ''' @@ -118,6 +124,9 @@ def emit_single_pass_through(self): v_gld_p = [ctrl.v_gld_a, ctrl.v_gld_b][p_idx] v_gld_q = [ctrl.v_gld_a, ctrl.v_gld_b][q_idx] + v_gld_p_gpf = [ctrl.v_gld_a_gpf, ctrl.v_gld_b_gpf][p_idx] + v_gld_p_num = [ctrl.v_gld_a_num, ctrl.v_gld_b_num][p_idx] + a_c = ctrl.a_c v_q = [ctrl.v_a, ctrl.v_b][q_idx] v_sld_q_os = [ctrl.v_sld_a_os, ctrl.v_sld_b_os][q_idx] @@ -155,8 +164,12 @@ def emit_single_pass_through(self): wave_repeat_p = [cxm.wave_repeat_m, cxm.wave_repeat_n][p_idx] wave_repeat_q = [cxm.wave_repeat_m, cxm.wave_repeat_n][q_idx] - v_pack_p = [ctrl.pass_through_v_pack_a, ctrl.pass_through_v_pack_b][p_idx] - v_pack_q = [ctrl.pass_through_v_pack_a, ctrl.pass_through_v_pack_b][q_idx] + p_interleave_gld = [ctrl.pass_through_a_interleave_gld, ctrl.pass_through_b_interleave_gld][p_idx] + + assert wave_repeat_q == 2, "currently the side need LDS must have repeat 2, following limitation seems have BUG" + + v_pack_p = [ctrl.pass_through_a_v_pack, ctrl.pass_through_b_v_pack][p_idx] + v_pack_q = [ctrl.pass_through_a_v_pack, ctrl.pass_through_b_v_pack][q_idx] assert v_pack_p == v_pack_q, "currently only support p, q the same" assert unroll_k % (v_pack_p * k_per_inst) == 0 @@ -190,7 +203,7 @@ def call_mbb(mbb): # parse global load of p tensor into list of single load mbb_gld_p = create_machine_basic_block(global_load_p()) - mbb_gld_q = create_machine_basic_block(global_load_q()) + mbb_gld_q = create_machine_basic_block(global_load_q(), merge_mbb = 1) mbb_p_clear = 1 if mbb_gld_p[0].mc_inst(-1).type() == MC_INST_TYPE_LEGACY_MACRO else 0 mbb_q_clear = 1 if mbb_gld_q[0].mc_inst(-1).type() == MC_INST_TYPE_LEGACY_MACRO else 0 @@ -205,7 +218,7 @@ def call_mbb(mbb): num_gld_p_per_issue = num_gld_p // (len(mbb_gld_p) - mbb_p_clear) def emit_v_clear_nc_p(i): with self._deferred_context(): - self._emit(f".v_clear_nc {v_gld_p(i * num_gld_p_per_issue)}, {num_gld_p_per_issue}") + self._emit(f".v_clear_nc {v_gld_p(i * num_gld_p_per_issue) if p_interleave_gld else v_gld_p_gpf(i * num_gld_p_per_issue)}, {num_gld_p_per_issue}") return self._get_deferred() mbb_gld_p_wrapper = list() @@ -343,18 +356,22 @@ def mfma_per_k_slot(i_k, i_mfma_v_pack_slot, is_last_fma): cnt_mfma = 0 def try_do_gld_per_slot(i_slot): if is_last_fma: - if i_k == 0: - mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] + if p_interleave_gld: + mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] if i_k == 0 else list() else: 
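Throughout this loop p names the tensor that bypasses LDS (its data stays in the v_gld_* registers straight from global memory) and q the tensor that is still staged through shared memory; p_interleave_gld then decides whether p's global loads are spread over the MFMA slots of each unroll step or issued as a single burst at the start of it. The spreading follows the utility_next_mul arithmetic used below, roughly (helper names here are illustrative):

    def next_mul(n, m):
        return ((n + m - 1) // m) * m

    def loads_per_slot(num_loads, wave_repeat_p, wave_repeat_q, v_pack_p):
        # one unroll step exposes wave_repeat_p * wave_repeat_q * v_pack_p issue slots;
        # outstanding global loads are spread evenly across them
        num_slots = wave_repeat_p * wave_repeat_q * v_pack_p
        return next_mul(num_loads, num_slots) // num_slots

    print(loads_per_slot(6, 2, 2, 1))   # 6 loads over 4 slots, 2 per slot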
mbb_gld_p_per_k = list() mbb_gld_per_k = mbb_gld_p_per_k else: - if i_k == 0: - mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] + if p_interleave_gld: + if i_k == 0: + mbb_gld_p_per_k = mbb_gld_p[len(mbb_gld_p) - gld_p_per_k : ] + else: + start_p_idx = mbb_p_clear if i_k == 1 else ((i_k - 1) * gld_p_per_k + mbb_p_clear) # always no clear + mbb_gld_p_per_k = mbb_gld_p[start_p_idx : i_k * gld_p_per_k + mbb_p_clear ] else: - start_p_idx = mbb_p_clear if i_k == 1 else ((i_k - 1) * gld_p_per_k + mbb_p_clear) # always no clear - mbb_gld_p_per_k = mbb_gld_p[start_p_idx : i_k * gld_p_per_k + mbb_p_clear ] - mbb_gld_per_k = mbb_gld_p_per_k + mbb_msw_pq + mbb_msw_acc + mbb_gld_q if i_k == 0 else mbb_gld_p_per_k + mbb_gld_p_per_k = mbb_gld_p if i_k == 0 else list() + mbb_gld_per_k = ((mbb_gld_p_per_k + mbb_msw_pq + mbb_msw_acc + mbb_gld_q) if p_interleave_gld else (mbb_gld_q + mbb_gld_p_per_k)) \ + if i_k == 0 else mbb_gld_p_per_k num_gld_slot_per_k = wave_repeat_p * wave_repeat_q * v_pack_p num_gld_per_slot = utility_next_mul(len(mbb_gld_per_k), num_gld_slot_per_k) // num_gld_slot_per_k for i_gld in range(num_gld_per_slot): @@ -388,34 +405,54 @@ def do_sld_q(i_v, i_r): if i_k == 0: + if not p_interleave_gld and not is_last_fma: + self._emit(move_slice_window_pq()) for mbb_1st in mbb_first_sld[1:]: self._emit(call_mbb(mbb_1st)) + if not p_interleave_gld and not is_last_fma: + self._emit(move_slice_window_acc()) for i_rp in range(wave_repeat_p): # cnt_p_load = cnt_p_load + 1 for i_rq in range(wave_repeat_q): - if i_rq != 0: - vmcnt_str = "" + num_lgkmcnt = (pref + rept - 2) - ((pref - 1 + i_rq) if i_k == (unroll_k_slot-1) else 0) + if not p_interleave_gld: + vmcnt_str = "vmcnt(0)" if i_k == 0 and i_rp == 0 and i_rq == 0 else \ + ( f"vmcnt({f_gld_p.get_issues()})" if num_lgkmcnt == 0 and not is_last_fma else "") else: - if i_k == 0: - vmcnt_str = f'vmcnt({num_p_issue - 1 - gld_p_per_k})' + if i_rq != 0 and wave_repeat_q != 1: + vmcnt_str = "" else: - if not is_last_fma: - vmcnt_str = f'vmcnt({num_p_issue + num_q_issue - 2})' + if i_k == 0: + vmcnt_str = f'vmcnt({num_p_issue - 1 - gld_p_per_k})' else: - vmcnt_str = f'vmcnt({num_p_issue - i_k - 1})' - num_lgkmcnt = (pref + rept - 2) - ((pref - 1 + i_rq) if i_k == (unroll_k_slot-1) else 0) - if MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT and num_lgkmcnt == 0: - # we need a change to put last lgkmcnt earlier - assert vmcnt_str == "" + if not is_last_fma: + vmcnt_str = f'vmcnt({num_p_issue + num_q_issue - 2})' + else: + vmcnt_str = f'vmcnt({num_p_issue - i_k - 1})' + + if MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT and num_lgkmcnt == 0 and p_interleave_gld: + # we need a chance to put last lgkmcnt earlier + # assert vmcnt_str == "" if is_last_fma: - self._emit(f's_waitcnt lgkmcnt(0)') + self._emit(f's_waitcnt lgkmcnt(0) ; vmcnt_str:{vmcnt_str}') + else: + # self._emit(f"; __ vmcnt_str:{vmcnt_str}") + pass else: self._emit(f's_waitcnt lgkmcnt({num_lgkmcnt}) {vmcnt_str}') + if num_lgkmcnt == 0 and not p_interleave_gld and not is_last_fma: + # self._emit(move_slice_window_acc()) + do_sst_q() + if i_k == 0 and i_rp == 0 and i_rq == 0: + if not p_interleave_gld and v_gld_p_gpf: + # move buffer + for i_pnum in range(v_gld_p_num): + self._emit(f"v_mov_b32 v[{v_gld_p(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}]") for i_v in range(v_pack_p): self._emit(mfma_step_pxq_vk(i_k, i_rp, i_rq, i_v, i_local_buffer_q)) - if MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT: + if MFMA_FEAT_SINGLE_PASS_THROUGH_EARLY_LAST_DS_WAIT and p_interleave_gld: if 
(i_mfma_v_pack_slot == mfma_v_pack_slot - 2) and (v_pack_p == 1 or i_v == (v_pack_p // 2) - 1): assert i_rq == 0 if not is_last_fma: @@ -461,21 +498,12 @@ def do_sld_q(i_v, i_r): self._emit(f".v_clear_acc_c {a_c()}, {cxm.total_acc_c()}") # self._emit(f"; make sure acc WAR harzard, at least 1 nop for src_c") - self._emit(f"s_waitcnt vmcnt({f_gld_p.get_issues() - wave_repeat_p * wave_step_p})") + self._emit(f"s_waitcnt vmcnt({f_gld_p.get_issues() - ((wave_repeat_p * wave_step_p) if p_interleave_gld else 0)})") self._emit(f_sst_q()) self._emit_empty_line() - - # decrese k - # self._emit(f"s_sub_i32 s[{s_kitr()}], s[{s_knum()}], {unroll_k}") - # self._emit(f"s_cmp_gt_i32 s[{s_kitr()}], 0") - # self._emit(f"s_cbranch_scc0 {label_mfma_end}") - # self._emit_empty_line() - - # right after clear acc - # self._emit(f_move_slice_window_p()) - # self._emit(f_move_slice_window_q()) - # if f_move_slice_window_acc != None: - # self._emit(f_move_slice_window_acc()) + # if not p_interleave_gld: + # self._emit(move_slice_window_pq()) + # self._emit(move_slice_window_acc()) self._emit(f"s_waitcnt lgkmcnt(0)") self._emit(f"s_barrier") diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py index cb21c84e..cf74e255 100755 --- a/igemm/algo/xdlops_mapping.py +++ b/igemm/algo/xdlops_mapping.py @@ -310,10 +310,12 @@ def serialize(self): ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 1, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 2, 1, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 32 , 32, 32, 2, 2, 1, 1, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 64 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 4, 1, 1, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 32 , 32 , 16, 16, 4, 2, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), #ctrl_xdlops_mapping_t( 256, 4 , 64, 4 , 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm #ctrl_xdlops_mapping_t( 4 , 256, 4 , 64, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), # TODO: small/skinny gemm ctrl_xdlops_mapping_t( 64 , 16 , 64, 4 , 1, 4, 1, 1, 1, 1, v_mfma_f32_4x4x1f32), From 830c21d5e0cf491c6c05c8cc6601683c7e241946 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 2 Mar 2021 18:19:45 +0800 Subject: [PATCH 25/40] enable magic div by default --- driver/igemm_fwd_gtc_driver.h | 4 ++-- igemm/algo/igemm_base.py | 4 ++-- igemm/algo/igemm_fwd_gtc_nhwc.py | 4 ++-- igemm/codegen/compile.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 6d9058e8..f2e28125 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -92,7 +92,7 @@ typedef struct { int x; int group; #if USE_MAGIC_DIV - uint32_t magic_0; // denom: gemm_n / n_per_block + uint32_t magic_0; // denom: (gemm_n + n_per_block - 1) / n_per_block uint32_t magic_1; // denom: ho*wo uint32_t magic_2; // denom: wo uint32_t magic_3; // denom: (gemm_m/m_per_block) * (gemm_n/n_per_block) @@ -532,7 +532,7 @@ class igemm_fwd_gtc_t { int gemm_m = n * ho * wo; int gemm_n = k / group; - magic_div_u32_t mdiv_0 = magic_div_u32_gen(gemm_n / gemm_n_per_block); + magic_div_u32_t mdiv_0 = magic_div_u32_gen((gemm_n + gemm_n_per_block - 1) / gemm_n_per_block); magic_div_u32_t 
mdiv_1 = magic_div_u32_gen(ho*wo); magic_div_u32_t mdiv_2 = magic_div_u32_gen(wo); magic_div_u32_t mdiv_3 = magic_div_u32_gen((gemm_m/gemm_m_per_block) * (gemm_n/gemm_n_per_block)); diff --git a/igemm/algo/igemm_base.py b/igemm/algo/igemm_base.py index 364d96b2..7dbe82a4 100755 --- a/igemm/algo/igemm_base.py +++ b/igemm/algo/igemm_base.py @@ -36,7 +36,7 @@ IGEMM_GTC_FEAT_PRECACHE_SOFFSET = 1 IGEMM_GTC_FEAT_LOCAL_PREFETCH = 1 IGEMM_GTC_FEAT_FMA_INTERLEAVE = 1 -IGEMM_GTC_FEAT_MAGIC_DIVISION = 0 +IGEMM_GTC_FEAT_MAGIC_DIVISION = 1 IGEMM_GTC_FEAT_SOURCE_ACCESS_ENCODING_KERNEL_NAME = 0 # IGEMM_GTC_TENSOR_LAYOUT_NCHW = ((1 << 4) | 0) @@ -189,7 +189,7 @@ def __init__(self, tunable_dict): default_source_access_order = IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M if (self.direction == 'fwd' and self.tensor_layout == 'nchw') \ else IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_M_GEMM_N - self.source_access_order = utility_dict_with_default_t(tunable_dict)('source_access_order', IGEMM_GTC_TUNABLE_SOURCE_ACCESS_ORDER_GEMM_N_GEMM_M) + self.source_access_order = utility_dict_with_default_t(tunable_dict)('source_access_order', default_source_access_order) self.gemm_m_unmerge_cluster = utility_dict_with_default_t(tunable_dict)('gemm_m_unmerge_cluster', 0) self.gemm_n_unmerge_cluster = utility_dict_with_default_t(tunable_dict)('gemm_n_unmerge_cluster', 0) diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index 4a382b71..a64f5d3a 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -537,8 +537,8 @@ def __init__(self, mc, outer): # allocate several sgpr to hold magic/shift value. self.s_magic_0 = sym_t("s_magic_0" ,self.s_p_in.value + 2) self.s_magic_1 = sym_t("s_magic_1" ,self.s_p_in.value + 3) - self.s_magic_2 = sym_t("s_magic_2" ,self.s_p_wei.value + 2) - self.s_magic_3 = sym_t("s_magic_3" ,self.s_p_wei.value + 3) + self.s_magic_2 = sym_t("s_magic_2" ,self.s_p_out.value + 2) + self.s_magic_3 = sym_t("s_magic_3" ,self.s_p_out.value + 3) self.s_shift_pack_0 = sym_t("s_shift_pack_0" ,self.s_flag_need_acc_yx.value) self.s_tmp = sym_t("s_tmp" ,sseq(6, 2)) diff --git a/igemm/codegen/compile.py b/igemm/codegen/compile.py index ad8fad9e..820f8550 100644 --- a/igemm/codegen/compile.py +++ b/igemm/codegen/compile.py @@ -31,7 +31,7 @@ IGEMM_HOST_USE_GPU_NAIVE_CONV = True IGEMM_HOST_USE_XDNN = False -IGEMM_HOST_USE_MAGIC_DIV = False +IGEMM_HOST_USE_MAGIC_DIV = True IGEMM_HOST_USE_HIPCC = True # hipclang perfer use hipcc to compile host code def _check_hip_clang(): From 7cbc45b5b43d661eeb50aab741c4cf2e2ba9d0f8 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 2 Mar 2021 18:29:22 +0800 Subject: [PATCH 26/40] fix a bug in magic number calculation --- driver/igemm_fwd_gtc_driver.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index f2e28125..fc7f92b9 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -532,10 +532,10 @@ class igemm_fwd_gtc_t { int gemm_m = n * ho * wo; int gemm_n = k / group; - magic_div_u32_t mdiv_0 = magic_div_u32_gen((gemm_n + gemm_n_per_block - 1) / gemm_n_per_block); + magic_div_u32_t mdiv_0 = magic_div_u32_gen(utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); magic_div_u32_t mdiv_1 = magic_div_u32_gen(ho*wo); magic_div_u32_t mdiv_2 = magic_div_u32_gen(wo); - magic_div_u32_t mdiv_3 = magic_div_u32_gen((gemm_m/gemm_m_per_block) * (gemm_n/gemm_n_per_block)); + magic_div_u32_t mdiv_3 = 
magic_div_u32_gen(utility_integer_divide_ceil(gemm_m, gemm_m_per_block) * utility_integer_divide_ceil(gemm_n, gemm_n_per_block)); karg.magic_0 = mdiv_0.magic; karg.magic_1 = mdiv_1.magic; karg.magic_2 = mdiv_2.magic; From 6b2ddb232440d211c3e75a5a969fca28f89269a8 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Wed, 3 Mar 2021 10:32:03 +0800 Subject: [PATCH 27/40] update small change --- igemm/algo/igemm_fwd_gtc_nhwc.py | 6 ++++-- igemm/algo/xdlops_mapping.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index a64f5d3a..a4237b1c 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -227,7 +227,8 @@ def expr(self): self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_c_itr()}]") else: self._emit(f"s_cmp_le_u32 s[{self.s_gemm_k_num_c()}], s[{self.s_in_offset()}]") - self._emit(f"s_cselect_b32 s[{self.s_flag_need_acc_yx()}], 1, 0") + if not self.tunable.tensor_a_pass_through and not self.tunable.tensor_b_pass_through: + self._emit(f"s_cselect_b32 s[{self.s_flag_need_acc_yx()}], 1, 0") self._emit_empty_line() class macro_move_slice_window_block_wise_acc_yx_t(macro_base_t): @@ -278,7 +279,8 @@ def expr(self): assert 'm_set_flag_nhw' in self.options m_set_flag_nhw = self.options['m_set_flag_nhw'] - self._emit(f"s_cmp_eq_u32 1, s[{self.s_flag_need_acc_yx()}]") + if not self.tunable.tensor_a_pass_through and not self.tunable.tensor_b_pass_through: + self._emit(f"s_cmp_eq_u32 1, s[{self.s_flag_need_acc_yx()}]") self._emit(f"s_cbranch_scc0 {label_acc_yx_end} ; no need do accumulate yx") self._emit_front(f"{label_acc_yx}:") if self.tunable.tensor_a_pass_through: diff --git a/igemm/algo/xdlops_mapping.py b/igemm/algo/xdlops_mapping.py index cf74e255..f9a5c46b 100755 --- a/igemm/algo/xdlops_mapping.py +++ b/igemm/algo/xdlops_mapping.py @@ -293,6 +293,7 @@ def serialize(self): ctrl_xdlops_mapping_t( 128, 64 , 32, 32, 2, 1, 2, 2, 2, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 128, 32 , 32, 8 , 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 128, 32 , 16, 16, 4, 4, 2, 2, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 128, 32 , 32, 32, 2, 4, 1, 1, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 128, 32 , 32, 32, 2, 2, 2, 1, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 32 , 128, 8 , 32, 1, 4, 2, 2, 1, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 128, 16, 64, 1, 4, 1, 1, 1, 1, v_mfma_f32_16x16x1f32), @@ -310,6 +311,7 @@ def serialize(self): ctrl_xdlops_mapping_t( 64 , 32 , 32, 8 , 1, 4, 1, 1, 1, 2, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 2, 1, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 64 , 32 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), + ctrl_xdlops_mapping_t( 64 , 48 , 16, 16, 4, 4, 1, 3, 1, 1, v_mfma_f32_16x16x4f32), ctrl_xdlops_mapping_t( 64 , 32 , 32, 32, 2, 2, 1, 1, 1, 1, v_mfma_f32_32x32x2f32), ctrl_xdlops_mapping_t( 32 , 64 , 8 , 32, 1, 4, 1, 1, 2, 1, v_mfma_f32_4x4x1f32), ctrl_xdlops_mapping_t( 32 , 64 , 16, 16, 4, 4, 1, 2, 1, 1, v_mfma_f32_16x16x4f32), From c40dc9bbc96988788ae4f15882a22b6d7d5377b7 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 11 Mar 2021 18:44:15 +0800 Subject: [PATCH 28/40] inference prototype for gfx1030 --- driver/args.h | 6 +- test/inference/build.sh | 13 + test/inference/igemm_fwd_btm_nhwc.h | 271 ++++++++++++++ test/inference/test_inference.cpp | 553 ++++++++++++++++++++++++++++ 4 files changed, 841 insertions(+), 2 deletions(-) create 
mode 100644 test/inference/build.sh create mode 100644 test/inference/igemm_fwd_btm_nhwc.h create mode 100644 test/inference/test_inference.cpp diff --git a/driver/args.h b/driver/args.h index 079ce7eb..83a16d7c 100644 --- a/driver/args.h +++ b/driver/args.h @@ -159,8 +159,10 @@ class args_t { static inline args_t create_conv_args(int argc, char *argv[]) { const std::string base("conv"); - if (argc >= 2 && argv[1] != base) { - printf("not proper base arg name"); + const std::string base_fp16("convfp16"); + const std::string base_bf16("convbf16"); + if (argc >= 2 && argv[1] != base && argv[1] != base_fp16 && argv[1] != base_bf16) { + printf("not proper base arg name\n"); exit(1); } diff --git a/test/inference/build.sh b/test/inference/build.sh new file mode 100644 index 00000000..6c5a6bf2 --- /dev/null +++ b/test/inference/build.sh @@ -0,0 +1,13 @@ +#!/bin/sh +# to launch from top of generator +ARCH=gfx1030 +rm -rf out +mkdir out + +/opt/rocm/hip/bin/hipcc -Idriver -std=c++14 -lpthread test/inference/test_inference.cpp -o out/test_inference.exe || exit 1 +/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/ test/inference/kernel/igemm_fwd_btm_nhwc_fp16.s -o out/igemm_fwd_btm_nhwc_fp16.hsaco || exit 1 +/opt/rocm/hip/bin/hipcc -x hip --cuda-gpu-arch=$ARCH --cuda-device-only -c -O3 driver/gpu_naive_conv/naive_conv.cpp -o out/naive_conv.hsaco + + + + diff --git a/test/inference/igemm_fwd_btm_nhwc.h b/test/inference/igemm_fwd_btm_nhwc.h new file mode 100644 index 00000000..cc96ae38 --- /dev/null +++ b/test/inference/igemm_fwd_btm_nhwc.h @@ -0,0 +1,271 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020-2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + *all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#ifndef __DIRECT_CONV_DRIVER_H +#define __DIRECT_CONV_DRIVER_H + + +#include +#include +#include +#include +#include + +#ifndef HIP_CALL +#define HIP_CALL(call) \ + do { \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + printf("[hiperror](%d) fail to call %s,(%s)", (int)err, #call, \ + hipGetErrorString(err)); \ + exit(1); \ + } \ + } while (0) +#endif + +static inline size_t gpu_conv_out_size(size_t in_size, size_t pad, + size_t dilation, size_t ksize, + size_t stride) { + return (in_size + 2 * pad - dilation * (ksize - 1) - 1) / stride + 1; +} + +typedef struct { + void * p_in; + void * p_wei; + void * p_out; + uint32_t hi; + uint32_t wi; + uint32_t n; + uint32_t k_per_group; + uint32_t c_per_group; + uint32_t ho; + uint32_t wo; + uint32_t sy; + uint32_t sx; + uint32_t dy; + uint32_t dx; + uint32_t py; + uint32_t px; + uint32_t fy; + uint32_t fx; + uint32_t group; + uint32_t batch_m; + uint32_t stride_m; + uint32_t magic_0; + uint32_t magic_1; + uint32_t shift_pack_0; + uint32_t __pack_0; +} __attribute__((packed)) igemm_fwd_btm_2d_karg_t; +static inline void dump_igemm_fwd_btm_2d_karg(igemm_fwd_btm_2d_karg_t * karg) +{ + std::cout<<"p_in:"<p_in<<", "; + std::cout<<"p_wei:"<p_wei<<", "; + std::cout<<"p_out:"<p_out<<", "; + std::cout<<"hi:"<hi<<", "; + std::cout<<"wi:"<wi<<", "; + std::cout<<"n:"<n<<", "; + std::cout<<"k_per_group:"<k_per_group<<", "; + std::cout<<"c_per_group:"<c_per_group<<", "; + std::cout<<"ho:"<ho<<", "; + std::cout<<"wo:"<wo<<", "; + std::cout<<"sy:"<sy<<", "; + std::cout<<"sx:"<sx<<", "; + std::cout<<"dy:"<dy<<", "; + std::cout<<"dx:"<dx<<", "; + std::cout<<"py:"<py<<", "; + std::cout<<"px:"<px<<", "; + std::cout<<"fy:"<fy<<", "; + std::cout<<"fx:"<fx<<", "; + std::cout<<"group:"<group<<", "; + std::cout<<"batch_m:"<batch_m<<", "; + std::cout<<"stride_m:"<stride_m<<", "; + std::cout<<"magic_0:"<magic_0<<", "; + std::cout<<"magic_1:"<magic_1<<", "; + std::cout<<"shift_pack_0:"<shift_pack_0<= 1000) + num_cu *= 2; + } + ~igemm_fwd_btm_t(){} + std::string get_kernel_name(const igemm_fwd_btm_kernel_info_t *kernel_info) { + return kernel_info->kernel_name; + } + + result_t run(const args_t *arg, hipModule_t module, igemm_fwd_btm_kernel_info_t * kernel_info, + void *p_in, void *p_wei, void *p_out, + int warmup, int repeat, const driverDataType_t& data_type) { + size_t hi = arg->get_int("in_h"); + size_t wi = arg->get_int("in_w"); + size_t n = arg->get_int("batchsize"); + size_t k = arg->get_int("out_channels"); + size_t c = arg->get_int("in_channels"); + + size_t sy = arg->get_int("conv_stride_h"); + size_t sx = arg->get_int("conv_stride_w"); + size_t dy = arg->get_int("dilation_h"); + size_t dx = arg->get_int("dilation_w"); + size_t py = arg->get_int("pad_h"); + size_t px = arg->get_int("pad_w"); + size_t fy = arg->get_int("fil_h"); + size_t fx = arg->get_int("fil_w"); + size_t ho = gpu_conv_out_size(hi, py, dy, fy, sy); + size_t wo = gpu_conv_out_size(wi, px, dx, fx, sx); + size_t group = arg->get_int("group_count"); + + assert(c % group == 0 && k % group == 0); + + assert(group != 0 && c % group == 0 && k % group == 0); + + size_t k_per_group = k / group; + size_t c_per_group = c / group; + igemm_fwd_btm_2d_karg_t karg; + karg.p_in = p_in; + karg.p_wei = p_wei; + karg.p_out = p_out; + karg.hi = static_cast(hi); + karg.wi = static_cast(wi); + karg.n = static_cast(n); + karg.k_per_group = static_cast(k_per_group); + karg.c_per_group = static_cast(c_per_group); + karg.ho = 
static_cast(ho); + karg.wo = static_cast(wo); + karg.sy = static_cast(sy); + karg.sx = static_cast(sx); + karg.dy = static_cast(dy); + karg.dx = static_cast(dx); + karg.py = static_cast(py); + karg.px = static_cast(px); + karg.fy = static_cast(fy); + karg.fx = static_cast(fx); + karg.group = static_cast(group); + size_t karg_size = sizeof(karg); + + void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, + HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, + HIP_LAUNCH_PARAM_END}; + + hipFunction_t kernel_func; + HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_info->kernel_name.c_str())); + + int block_size = kernel_info->block_size; + int grid_size = kernel_info->occupancy * num_cu; + grid_size = env_get_int("GRID_SIZE", grid_size); + int b_grids = (ho * wo + kernel_info->m_per_block - 1) / kernel_info->m_per_block; + + karg.batch_m = (b_grids + grid_size - 1) / grid_size; + karg.stride_m = kernel_info->m_per_block * grid_size; + + magic_div_u32_t mdiv_0 = magic_div_u32_gen(fx); + magic_div_u32_t mdiv_1 = magic_div_u32_gen(wo); + karg.magic_0 = mdiv_0.magic; + karg.magic_1 = mdiv_1.magic; + karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, 0, 0); + + // printf("launch fwd block:%d, grid:%d\n", block_size, grid_size); + // dump_igemm_fwd_btm_2d_karg(&karg); + + auto launch_fwd = [&]() -> float { + void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, + HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, + HIP_LAUNCH_PARAM_END}; + float ms = .0; + + hipEvent_t start; + hipEvent_t stop; + hipEventCreate(&start); + hipEventCreate(&stop); + + // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem + HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, n, group, + block_size, 1, 1, 0, 0, NULL, + (void **)&config, start, stop)); + + hipEventSynchronize(stop); + hipEventElapsedTime(&ms, start, stop); + hipEventDestroy(start); + hipEventDestroy(stop); + + return ms; + }; + + for (int i = 0; i < warmup; i++) { + launch_fwd(); + } + + std::vector duration_list; + for (int i = 0; i < repeat; i++) { + float d = launch_fwd(); + duration_list.push_back(d); + } + + // remove min and max from list, then do average + auto imin = std::min_element(begin(duration_list), end(duration_list)); + duration_list.erase(imin); + auto imax = std::max_element(begin(duration_list), end(duration_list)); + duration_list.erase(imax); + assert(duration_list.size() == (repeat - 2)); + float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + + usleep(1000 * 1); + + result_t result; + result.return_code = 0; + result.duration_ms = avg_duration; + result.kernel_name = kernel_info->kernel_name; + return result; + } +}; + +#endif diff --git a/test/inference/test_inference.cpp b/test/inference/test_inference.cpp new file mode 100644 index 00000000..0ee64767 --- /dev/null +++ b/test/inference/test_inference.cpp @@ -0,0 +1,553 @@ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "args.h" + +#define USE_HALF_HPP + +#ifdef USE_HALF_HPP +#include "half.hpp" +using float16 = half_float::half; +#endif + +std::string parse_base_arg(int argc, char* argv[]) +{ + if(argc < 2) + { + printf("Invalid Number of Input Arguments\n"); + exit(0); + } + + std::string arg = argv[1]; + + if(arg != "conv" && arg != "convfp16" && arg != "--version") + { + printf("Invalid Base Input Argument\n"); + 
exit(0); + } + else if(arg == "-h" || arg == "--help" || arg == "-?") + exit(0); + else + return arg; +} + +static inline size_t conv_out_size(size_t in_size, size_t pad, size_t dilation, + size_t ksize, size_t stride) { + return (in_size + 2 * pad - dilation * (ksize - 1) - 1) / stride + 1; +} +typedef struct { + uint32_t magic; + uint8_t shift; +} magic_div_u32_t; +static inline magic_div_u32_t magic_div_u32_gen(uint32_t d) { + assert(d >= 1 && d <= INT32_MAX); + uint8_t shift; + for (shift = 0; shift < 32; shift++) + if ((1U << shift) >= d) + break; + + uint64_t one = 1; + uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; + assert(magic <= 0xffffffffUL); + + magic_div_u32_t result; + result.magic = magic; + result.shift = shift; + return result; +} +static inline uint32_t magic_div_u32_pack_shift(uint8_t s0, uint8_t s1, uint8_t s2, uint8_t s3) +{ + uint32_t shift_0 = static_cast(s0); + uint32_t shift_1 = static_cast(s1); + uint32_t shift_2 = static_cast(s2); + uint32_t shift_3 = static_cast(s3); + return (shift_3 << 24) | (shift_2 << 16) | (shift_1 << 8) | shift_0; +} +typedef struct { + int return_code; + float duration_ms; + float gflops; + float efficiency; + std::string kernel_name; +} result_t; + + +typedef enum { + driverHalf = 0, /*!< 16-bit floating point (Fully supported) */ + driverFloat = 1, /*!< 32-bit floating point (Fully supported) */ + driverBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction) + (Partially supported) */ +} driverDataType_t; + +static inline int env_get_int(const char *var_name, int default_int) { + char *v = getenv(var_name); + int r = default_int; + if (v) + r = atoi(v); + return r; +} + +#define NAIVE_CONV_THREADED +#include "naive_conv.h" +#include "gpu_naive_conv.h" +#include "igemm_fwd_btm_nhwc.h" + + +#define HIP_CALL(call) \ + do { \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + printf("[hiperror](%d) fail to call %s,(%s)", (int)err, #call, \ + hipGetErrorString(err)); \ + exit(1); \ + } \ + } while (0) + +static int gen_rand_integer() +{ + static int inited = 0; + if(inited == 0) + { + std::srand(std::time(nullptr)); + inited = 1; + } + return std::rand(); +} + + +static inline char *env_get_str(const char *var_name, char *default_str) { + char *v = getenv(var_name); + if (v) + return v; + return default_str; +} + +template +void block_wise_tensor_copy(Dst_T *p_dst, Src_T *p_src, int tid, int block_size, int total_size) +{ + for (int i = tid; i < total_size; i += block_size) { + p_dst[i] = static_cast(p_src[i]); + } +} + +template +void tensor_copy(Dst_T *p_dst, Src_T *p_src, size_t tensor_size) { + int num_threads = std::thread::hardware_concurrency(); + if (num_threads < 4) + num_threads = 4; + + std::vector threads; + for (int t = 0; t < num_threads; t++) { + threads.push_back(std::thread(block_wise_tensor_copy, + p_dst, p_src, t, num_threads, tensor_size)); + } + for (auto &th : threads) + th.join(); +} + +template +struct distribution_t{ +}; + +template <> +struct distribution_t{ + distribution_t(int min, int max) : distribution(min, max) {} + template + int operator()(URNG & rng){ return distribution(rng);} + std::uniform_int_distribution distribution; +}; +template <> +struct distribution_t{ + distribution_t(float min, float max) : distribution(min, max) {} + template + float operator()(URNG & rng){ return distribution(rng);} + std::uniform_real_distribution distribution; +}; + +template +void block_wise_rand_generator(Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T 
max, Src_T scale) +{ + std::mt19937 rng(std::chrono::system_clock::now() + .time_since_epoch() + .count() + + std::hash()(std::this_thread::get_id())); + distribution_t distribution(min,max); + for (int i = tid; i < total_size; i += block_size) { + p[i] = static_cast(scale * distribution(rng)); + } +} + +template +void gen_rand_vector(Dst_T *vec, size_t vec_size, Src_T fmin, Src_T fmax, Src_T scale = 1) { + int num_threads = std::thread::hardware_concurrency(); + if (num_threads < 4) + num_threads = 4; + // printf("total threads:%d\n",num_threads); + std::vector threads; + for (int t = 0; t < num_threads; t++) { + threads.push_back(std::thread(block_wise_rand_generator, + vec, t, num_threads, vec_size, fmin, fmax, scale)); + } + for (auto &th : threads) + th.join(); +} + +static inline bool valid_float(float p) +{ + return !(std::isnan(p) || std::isinf(p)); +} +#ifndef ABS +#define ABS(b) ((b) > 0 ? (b) : -1 * (b)) +#endif +template +static inline bool valid_vector(const float *ref, const T *pred, size_t n, + double nrms = 1e-6) { + double s0 = 0.0; + double s1 = 0.0; + int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0); + int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1); + size_t pp_err = 0; + + for (size_t i = 0; i < n; ++i) { + double ri = (double)ref[i]; + double pi = (double)pred[i]; + if(!(valid_float(ref[i]) && valid_float(pred[i]))){ + printf(" invalid float at %4zu, ref:%f, pred:%f\n", i, ri, pi); + return false; + } + double d = ri - pi; + double dd = d * d; + double rr = 2.0 * ri * ri; + s0 += dd; + s1 += rr; + if(igemm_per_pixel_check){ + double delta = ABS(ABS(ri - pi) / ri); + printf("[%zu] ref:%lf, pred:%lf(0x%08x) [%s]\n", i, ri, pi, ((uint32_t *)pred)[i], delta > 3e-5? "N":"Y"); + if (delta > 3e-5) { + if(igemm_per_pixel_check_print){ + if (pp_err < 100) + printf("diff at %zu, ref:%lf, pred:%lf(0x%08x), d:%lf\n", i, ri, + pi, ((uint32_t *)pred)[i], delta); + } + pp_err++; + } + } + } + // printf("\nnrms:%lf, s0:%lf, s1:%lf, expected_nrms is %1f\n",sqrt(s0/s1),s0,s1,nrms); + return (sqrt(s0 / s1) < nrms) +#ifdef PER_PIXEL_CHECK + && (pp_err == 0) +#endif + ; +} + +static inline void dump_output_dword(const float *out, size_t n) +{ + for (size_t i = 0; i < n; ++i) { + double pi = (double)out[i]; + printf("[%zu] pred:%lf(0x%08x)\n", i, pi, ((uint32_t *)out)[i]); + } +} + +static inline double theoritical_gflops(double sclk_ghz, size_t cu, + size_t simd) { + return 2 * sclk_ghz * cu * simd; +} + +static inline double +theoritical_conv_flop(size_t n, size_t c, size_t hi, size_t wi, size_t k, + size_t y, size_t x, size_t stride_h, size_t stride_w, + size_t dilation_h, size_t dilation_w, size_t pad_h, + size_t pad_w, size_t ngroups) { + size_t ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + size_t wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + + double flop = (double)n * c * ho * wo * k * y * x * 2 / ngroups; + return flop; +} +static inline double +measured_conv_gflops(double time_ms, size_t n, size_t c, size_t hi, + size_t wi, size_t k, size_t y, size_t x, + size_t stride_h, size_t stride_w, size_t dilation_h, + size_t dilation_w, size_t pad_h, size_t pad_w, size_t ngroups) { + double flop = + theoritical_conv_flop(n, c, hi, wi, k, y, x, stride_h, stride_w, + dilation_h, dilation_w, pad_h, pad_w, ngroups); + return flop / (time_ms * 1e6); +} + +static inline double get_nrms(int forw, driverDataType_t driver_data_type){ + auto basic_tolerance = [=]() -> double{ + if (driver_data_type == driverFloat){ +#ifdef USE_XDNN + 
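+            // the bound is a normalized RMS error: valid_vector() accepts a result when
+            // sqrt(sum((ref - pred)^2) / sum(2 * ref^2)) stays below nrms; the xDNN
+            // reference path gets a looser fp32 bound than the built-in naive conv.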
return 5e-5; +#else + return 1.5e-6; +#endif + } + else if (driver_data_type == driverHalf){ +#ifdef USE_XDNN + return 5*8.2e-3; +#else + return 8.2e-3; +#endif + } + }; + double nrms = basic_tolerance(); + // wrw has a high tolerance + if (forw == 4){ + nrms *= 2; + if(driver_data_type == driverFloat){ + nrms = 0.01; + } + else if(driver_data_type == driverHalf){ + nrms *= 5; + } + } + return nrms; +} + +#define GPU_NAIVE_CONV_HSACO "naive_conv.hsaco" +#define SCLK_MHZ 2200 +#define WARMUP 3 +#define REPEAT 8 + +#ifndef HSACO +#define HSACO "igemm_fwd_btm_nhwc_fp16.hsaco" +#endif +int main(int argc, char **argv){ + int warmup = env_get_int("WARMUP", WARMUP); + int repeat = env_get_int("REPEAT", REPEAT); + int sclk_mhz = env_get_int("SCLK_MHZ", SCLK_MHZ); + int dump_out = env_get_int("DUMP_OUT", 0); + char *hsaco = env_get_str("HSACO", HSACO); + char *gpu_naive_conv_hsaco = env_get_str("GPU_NAIVE_CONV_HSACO", GPU_NAIVE_CONV_HSACO); + gpu_naive_conv_init(gpu_naive_conv_hsaco); + + std::string base_arg = parse_base_arg(argc, argv); + + driverDataType_t driver_data_type; + if(base_arg == "conv") + driver_data_type = driverFloat; + else if(base_arg == "convfp16") + driver_data_type = driverHalf; + else if(base_arg == "convbf16") { + driver_data_type = driverBFloat16; + exit(0); + } + else + exit(0); + + hipModule_t module; + HIP_CALL(hipModuleLoad(&module, hsaco)); + + args_t conv_args = create_conv_args(argc, argv); + // dump_arg(&conv_args); + + int hi = conv_args.get_int("in_h"); + int wi = conv_args.get_int("in_w"); + int n = conv_args.get_int("batchsize"); + int k = conv_args.get_int("out_channels"); + int c = conv_args.get_int("in_channels"); + + int stride_h = conv_args.get_int("conv_stride_h"); + int stride_w = conv_args.get_int("conv_stride_w"); + int dilation_h = conv_args.get_int("dilation_h"); + int dilation_w = conv_args.get_int("dilation_w"); + int pad_h = conv_args.get_int("pad_h"); + int pad_w = conv_args.get_int("pad_w"); + int y = conv_args.get_int("fil_h"); + int x = conv_args.get_int("fil_w"); + int ngroups = conv_args.get_int("group_count"); + int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + int forw = conv_args.get_int("forw"); + + int need_fwd = (forw == 0 ? 1 : (forw & 1 ? 1 : 0)); + int need_bwd = (forw == 0 ? 1 : (forw & 2 ? 1 : 0)); + int need_wrw = (forw == 0 ? 1 : (forw & 4 ? 
1 : 0)); + + // init host side + float *host_input = (float *)malloc(static_cast(n) * c * hi * wi * sizeof(float)); + float *host_weight = (float *)malloc(static_cast(k) * c * y * x * sizeof(float)); + float *host_output = (float *)malloc(static_cast(n) * k * ho * wo * sizeof(float)); + + float *device_input; + float *device_weight; + float *device_output; + + HIP_CALL(hipMalloc(&device_input, static_cast(n) * c * hi * wi * sizeof(float))); + HIP_CALL(hipMalloc(&device_weight, static_cast(k) * c * y * x * sizeof(float))); + HIP_CALL(hipMalloc(&device_output, static_cast(n) * k * ho * wo * sizeof(float))); + +#ifdef USE_HALF_HPP + // fp16 type + float16 *host_input_f16 = (float16 *)malloc(n * c * hi * wi * sizeof(float16)); + float16 *host_weight_f16 = (float16 *)malloc(k * c * y * x * sizeof(float16)); + float16 *host_output_f16 = (float16 *)malloc(n * k * ho * wo * sizeof(float16)); + + float16 *device_input_f16; + float16 *device_weight_f16; + float16 *device_output_f16; + + HIP_CALL(hipMalloc(&device_input_f16, n * c * hi * wi * sizeof(float16))); + HIP_CALL(hipMalloc(&device_weight_f16, k * c * y * x * sizeof(float16))); + HIP_CALL(hipMalloc(&device_output_f16, n * k * ho * wo * sizeof(float16))); +#endif + + int need_verify = conv_args.get_int("verify"); + + int num_cu; + int num_simd = 64; // hard coded + int gcn_arch = 0; + + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + gcn_arch = dev_prop.gcnArch; + if(gcn_arch >= 1000) + num_cu *= 2; +#if 0 +#define P_DEVICE_PROP_INT(prop) \ + printf(#prop":%d\n", dev_prop.prop) + + + P_DEVICE_PROP_INT(clockRate); + P_DEVICE_PROP_INT(memoryClockRate); + P_DEVICE_PROP_INT(memoryBusWidth); + P_DEVICE_PROP_INT(major); + P_DEVICE_PROP_INT(minor); + P_DEVICE_PROP_INT(gcnArch); +#endif + } + + double theo_gflops = theoritical_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd * 2/*fp16, 2x speed*/); + double nrms = get_nrms(forw, driver_data_type); + + printf("num_cu:%d, gcn_arch:%d, theo_gflops:%f\n", num_cu, gcn_arch, theo_gflops); + + if (need_fwd){ + float *device_output_to_host = NULL; + if (need_verify) { + // gen rand + //gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 0.0, 1.0); + //gen_rand_vector(host_weight, static_cast(k) * c * y * x, -0.5, 0.5); + gen_rand_vector(host_input, static_cast(n) * c * hi * wi, -5, 5); + gen_rand_vector(host_weight, static_cast(k) * c * y * x, -2, 2); + //gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 1, 1); + //gen_rand_vector(host_weight, static_cast(k) * c * y * x, 1, 1); + +#ifdef USE_HALF_HPP + if(driver_data_type == driverHalf){ + // move to different data type + tensor_copy(host_input_f16, host_input, static_cast(n) * c * hi * wi); + tensor_copy(host_weight_f16, host_weight, static_cast(k) * c * y * x); + } +#endif + + HIP_CALL(hipMemcpy(device_input, host_input, + static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); + HIP_CALL(hipMemcpy(device_weight, host_weight, + static_cast(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); + + gpu_naive_conv_fwd_nhwc_fp32(device_input, device_weight, device_output, + n, wi, hi, c, + k, x, y, pad_w, pad_h, stride_w, stride_h, + dilation_w, dilation_h, ngroups); + HIP_CALL(hipDeviceSynchronize()); + HIP_CALL(hipMemcpy(host_output, device_output, + static_cast(n) * k * ho * wo * sizeof(float), + hipMemcpyDeviceToHost)); + + if(driver_data_type == driverHalf){ +#ifdef USE_HALF_HPP + 
device_output_to_host = (float *)malloc((static_cast(n) * k * ho * wo * sizeof(float16) + 3) / 4 * 4); +#endif + } + else{ + device_output_to_host = (float *)malloc(static_cast(n) * k * ho * wo * sizeof(float)); + } + + } + if(driver_data_type == driverFloat){ + HIP_CALL(hipMemcpy(device_input, host_input, + static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); + HIP_CALL(hipMemcpy(device_weight, host_weight, + static_cast(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); + } +#ifdef USE_HALF_HPP + else if(driver_data_type == driverHalf){ + HIP_CALL(hipMemcpy(device_input_f16, host_input_f16, + static_cast(n) * c * hi * wi * sizeof(float16), hipMemcpyHostToDevice)); + HIP_CALL(hipMemcpy(device_weight_f16, host_weight_f16, + static_cast(k) * c * y * x * sizeof(float16), hipMemcpyHostToDevice)); + } +#endif + igemm_fwd_btm_t conv_fwd_driver; + for (int i = 0; i < sizeof(igemm_fwd_btm_kernel_list)/sizeof(igemm_fwd_btm_kernel_list[0]); i++) { + igemm_fwd_btm_kernel_info_t *kinfo = &igemm_fwd_btm_kernel_list[i]; + + + printf("[fwd:%2d] %s, ", i, conv_fwd_driver.get_kernel_name(kinfo).c_str()); + fflush(stdout); + + result_t result; + +#ifdef USE_HALF_HPP + result = conv_fwd_driver.run(&conv_args, module, kinfo, device_input_f16, + device_weight_f16, device_output_f16, warmup, repeat, driver_data_type); +#endif + + double gflops = measured_conv_gflops( + result.duration_ms, n, c, hi, wi, k, y, x, stride_h, stride_w, + dilation_h, dilation_w, pad_h, pad_w, ngroups); + printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms, + gflops / 1000 , (gflops / theo_gflops) * 100); + if (need_verify) { + bool is_valid; + if(driver_data_type == driverFloat) { + HIP_CALL(hipMemcpy(device_output_to_host, device_output, + static_cast(n) * k * ho * wo * sizeof(float), + hipMemcpyDeviceToHost)); + is_valid = valid_vector(host_output, device_output_to_host, + static_cast(n) * k * ho * wo, nrms); + } +#ifdef USE_HALF_HPP + else if(driver_data_type == driverHalf) { + HIP_CALL(hipMemcpy(device_output_to_host, device_output_f16, + static_cast(n) * k * ho * wo * sizeof(float16), + hipMemcpyDeviceToHost)); + if(dump_out) + dump_output_dword(device_output_to_host, static_cast(n) * k * ho * wo / 2); + float16 *device_output_to_host_fp16 = (float16 *)device_output_to_host; + is_valid = valid_vector(host_output, device_output_to_host_fp16, + static_cast(n) * k * ho * wo, nrms); + } +#endif + printf(", valid:%s", is_valid ? 
"y" : "n"); + } + printf("\n"); + } + + if (need_verify){ + free(device_output_to_host); + } + } +} \ No newline at end of file From 6359fe5646d22cc7911c8b4ecdbf2d0f80af1241 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 11 Mar 2021 18:56:55 +0800 Subject: [PATCH 29/40] gfx1030 kernel asm --- test/inference/build.sh | 2 +- .../kernel/igemm_fwd_btm_nhwc_fp16.asm | 116 ++++ .../igemm_fwd_btm_nhwc_fp16_128x016.asm | 596 ++++++++++++++++++ .../igemm_fwd_btm_nhwc_fp16_256x016.asm | 589 +++++++++++++++++ 4 files changed, 1302 insertions(+), 1 deletion(-) create mode 100644 test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm create mode 100644 test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm create mode 100644 test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm diff --git a/test/inference/build.sh b/test/inference/build.sh index 6c5a6bf2..ea241fb5 100644 --- a/test/inference/build.sh +++ b/test/inference/build.sh @@ -5,7 +5,7 @@ rm -rf out mkdir out /opt/rocm/hip/bin/hipcc -Idriver -std=c++14 -lpthread test/inference/test_inference.cpp -o out/test_inference.exe || exit 1 -/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/ test/inference/kernel/igemm_fwd_btm_nhwc_fp16.s -o out/igemm_fwd_btm_nhwc_fp16.hsaco || exit 1 +/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/ test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm -o out/igemm_fwd_btm_nhwc_fp16.hsaco || exit 1 /opt/rocm/hip/bin/hipcc -x hip --cuda-gpu-arch=$ARCH --cuda-device-only -c -O3 driver/gpu_naive_conv/naive_conv.cpp -o out/naive_conv.hsaco diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm new file mode 100644 index 00000000..cfa97782 --- /dev/null +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm @@ -0,0 +1,116 @@ +; pay attention to register bank of v_c, v_b +.macro .fma_1x16_fp16 v_c, v_a, v_b + v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] + v_dot2c_f32_f16 v[\v_c+8 ], v[\v_a], v[\v_b+8 ] + v_dot2c_f32_f16 v[\v_c+9 ], v[\v_a], v[\v_b+9 ] + v_dot2c_f32_f16 v[\v_c+10], v[\v_a], v[\v_b+10] + v_dot2c_f32_f16 v[\v_c+11], v[\v_a], v[\v_b+11] + v_dot2c_f32_f16 v[\v_c+12], v[\v_a], v[\v_b+12] + v_dot2c_f32_f16 v[\v_c+13], v[\v_a], v[\v_b+13] + v_dot2c_f32_f16 v[\v_c+14], v[\v_a], v[\v_b+14] + v_dot2c_f32_f16 v[\v_c+15], v[\v_a], v[\v_b+15] +.endm + +.macro .fma_1x8_fp16 v_c, v_a, v_b + v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] +.endm + +.macro .fma_1x4_fp16 v_c, v_a, v_b + v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] +.endm + +.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp + 
s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer] + s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer] + s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift] +.endm + +.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp + .mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp + s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot] + s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp] +.endm + +.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp + v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer] + v_add_nc_u32 v[\v_tmp], v[\v_tmp], v[\v_numer] + v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp] +.endm + +.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp + .mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp + v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot] + v_sub_nc_u32 v[\v_rem], v[\v_numer], v[\v_tmp] +.endm + +.macro .v_clear_nc vid, num + _v = \vid + .rept \num + v_mov_b32 v[_v], 0 + _v = _v + 1 + .endr +.endm + +.include "igemm_fwd_btm_nhwc_fp16_128x016.asm" +.include "igemm_fwd_btm_nhwc_fp16_256x016.asm" + +.amdgpu_metadata +--- +amdhsa.version: [ 1, 0 ] +amdhsa.kernels: + - .name: igemm_fwd_btm_nhwc_fp16_128x16x16_r3 + .symbol: igemm_fwd_btm_nhwc_fp16_128x16x16_r3.kd + .sgpr_count: 58 + .vgpr_count: 74 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: 
i32} + - { .name: shift_pack_0, .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: __pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} +... +.end_amdgpu_metadata diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm new file mode 100644 index 00000000..9eb0104c --- /dev/null +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm @@ -0,0 +1,596 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_shift_pack_0, 104 + +.set s_block_ib, 2 ; bx, ho*wo +.set s_ka, 0 +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_shift_pack_0, 36 +.set s_shift_m0, 37 +.set s_shift_m1, s_shift_pack_0 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 38 +.set s_in_diff_hi, 39 +.set s_in_diff_wi, 40 +.set s_dilation_w_x, 41 +.set s_move_slice_k_ix, 42 + +.set s_kitr, 1 +.set s_wei_offset, 43 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 44 +.set s_br, 45 + +.set s_tmp, 46 +.set s_end, 52 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 16 +.set v_a, 17 +.set v_ib, 25 +.set v_b, 26 +.set v_gld_a, 58 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+12 +.set v_wei_ix_list, v_b+15 +.set v_wei_flag, v_b+18 +.set v_wei_os, v_b+21 +.set v_tmp, v_b+24 +.set v_wei_ik, v_a +.set v_wei_ic, v_a+1 +.set v_wei_ie, v_a+2 +.set v_wei_flag_ik, v_a+3 +.set v_sst_b_os, v_a+4 +.set v_in_os, 66 +.set v_in_ihi, 67 +.set v_in_iwi, 68 +.set v_in_flag, 69 +.set v_out_os, 70 +.set v_out_flag, 71 +.set v_tid, 72 +.set v_end, 74 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_128x16x16_r3 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_128x16x16_r3,@function +igemm_fwd_btm_nhwc_fp16_128x16x16_r3: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0 + v_mov_b32 v[v_tid], v0 + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 4 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], 17*4 * 4 ; 17dword per row, 4 row + v_and_b32 v[v_tmp+5], 7, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + s_lshl_b32 s[s_block_ib], s[s_block_ib], 7 ; 128 half + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, 
n*k_pack->16dword, pad to 17 + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 3, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 2+1 ; 4x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+2], s[s_wei_offset], v[v_wei_os+1] + v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + .mdiv_u32_rem_vs v_wei_ix_list+2,v_wei_iy_list+2,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+2] + v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+2] + v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag+2] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0:v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1:v_wei_os+2], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+2] + global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2:v_wei_os+3], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 32*17*4 ; stride for wei sst offset. 
8 thread for k, each thread store 4 c, hence 8*4=32 + ; 17 dword per row + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_gld_a, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_gld_a+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + v_add_nc_u32 v[v_sst_b_os+2], s[s_tmp+5], v[v_sst_b_os+1] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + + s_mov_b32 s[s_sld_b_stride], 17*8*4 + + s_waitcnt vmcnt(2) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:17*2 offset1:17*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:17*2 offset1:17*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+2] + ds_write2_b32 v[v_sst_b_os+2], v[v_gld_b+8], v[v_gld_b+9], offset0:17*0 offset1:17*1 + 
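+    ; weight tile rows in LDS hold 16 dwords of data padded to a 17-dword pitch
+    ; (offset0/offset1 are dword-granular), keeping consecutive rows in different banks;
+    ; each ds_write2_b32 therefore stores one dword into two adjacent rows.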
ds_write2_b32 v[v_sst_b_os+2], v[v_gld_b+10], v[v_gld_b+11], offset0:17*2 offset1:17*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + s_cmp_gt_i32 s[s_kitr], 0 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end + +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_body: + ; accumulate im + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + + ; accumulate b + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi], s[s_tmp], v[v_in_iwi] + v_add_nc_u32 v[v_in_os], s[s_tmp+1], v[v_in_os] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi], s[s_dilation_h], v[v_in_ihi] +igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] +igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: + + s_waitcnt vmcnt(0) + v_mov_b32 v[v_a + 0], v[v_gld_a + 0] + v_mov_b32 v[v_a + 1], v[v_gld_a + 1] + v_mov_b32 v[v_a + 2], v[v_gld_a + 2] + v_mov_b32 v[v_a + 3], v[v_gld_a + 3] + v_mov_b32 v[v_a + 4], v[v_gld_a + 4] + v_mov_b32 v[v_a + 5], v[v_gld_a + 5] + v_mov_b32 v[v_a + 6], v[v_gld_a + 6] + v_mov_b32 v[v_a + 7], v[v_gld_a + 7] + .v_clear_nc v_gld_a, 8 + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 + + + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + + + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + s_waitcnt lgkmcnt(4) + + + .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + s_waitcnt lgkmcnt(4) + + + .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], 
v[v_sld_b_os], offset:17*4*4 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + s_waitcnt lgkmcnt(4) + + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] + + .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_body + +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end: + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + s_waitcnt vmcnt(0) + + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_mov_b32 v[v_a + 0], v[v_gld_a + 0] + v_mov_b32 v[v_a + 1], v[v_gld_a + 1] + v_mov_b32 v[v_a + 2], v[v_gld_a + 2] + v_mov_b32 v[v_a + 3], v[v_gld_a + 3] + v_mov_b32 v[v_a + 4], v[v_gld_a + 4] + v_mov_b32 v[v_a + 5], v[v_gld_a + 5] + v_mov_b32 v[v_a + 6], v[v_gld_a + 6] + v_mov_b32 v[v_a + 7], v[v_gld_a + 7] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_gld_a, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_gld_a+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + s_mov_b32 s[s_move_slice_k_ix], 0 + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 
v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next: + + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + s_waitcnt lgkmcnt(2) + .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c 
+11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + + .v_clear_nc v_c, 16 + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_body +L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_end: + s_endpgm + +; LDS: (16+1) * (3 * 4) * 16 * 4 = 13056 +; k pad r3 e:4 c dword +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_128x16x16_r3 + .amdhsa_group_segment_fixed_size 13056 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 74 + .amdhsa_next_free_sgpr 52 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm new file mode 100644 index 00000000..91f5da00 --- /dev/null +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm @@ -0,0 +1,589 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_shift_pack_0, 104 + +.set s_block_ib, 2 ; bx, ho*wo +.set s_ka, 0 +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_shift_pack_0, 36 +.set s_shift_m0, 37 +.set s_shift_m1, s_shift_pack_0 + + +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 
+.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 38 + +.set s_in_diff_hi, 39 +.set s_in_diff_wi, 40 +.set s_dilation_w_x, 41 +.set s_move_slice_k_ix, 42 + +.set s_hi_diff_batch_m, 43 +.set s_wi_diff_batch_m, 44 +.set s_in_os_diff_batch_m, 45 + +.set s_kitr, 1 +.set s_wei_offset, 46 +.set s_out_stride, s_wei_offset +.set s_vmcnt_out, s_magic_1 +.set s_br, 47 + +.set s_tmp, 48 +.set s_end, 54 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 32 +.set v_a, 33 +.set v_sst_b_os, 25 +.set v_b, 26 +.set v_gld_a, 58 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+12 +.set v_wei_ix_list, v_b+15 +.set v_wei_flag, v_b+18 +.set v_wei_os, v_b+21 +.set v_tmp, v_b+24 +.set v_wei_ik, v_a +.set v_wei_ic, v_a+1 +.set v_wei_ie, v_a+2 +.set v_wei_flag_ik, v_a+3 +.set v_ib, v_a+4 +.set v_in_os, 66 +.set v_in_ihi, 67 +.set v_in_iwi, 68 +.set v_in_flag, 69 +.set v_out_os, 70 +.set v_out_flag, 71 +.set v_end, 72 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_256x16x16_r3 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_256x16x16_r3,@function +igemm_fwd_btm_nhwc_fp16_256x16x16_r3: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0 + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 4 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + s_lshl_b32 s[s_block_ib], s[s_block_ib], 8 ; 256 + v_mov_b32 v[v_ib], v0 + v_mov_b32 v[v_sld_b_os], 0 ; load + + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to 17 + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 3, v[v_tmp+4] ; yx + + s_waitcnt lgkmcnt(0) + + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00000008 ; offset:0, width:8 + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 2+1 ; 4x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+2], s[s_wei_offset], v[v_wei_os+1] + v_add_nc_u32 v[v_wei_ie], 4, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] 
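+    ; same pattern for the third weight pointer: split the flattened y*x index in
+    ; v_wei_ie by x via the magic divider (magic_0/shift_m0), then fold the y/x bounds
+    ; checks into the per-pointer valid flag.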
+ + .mdiv_u32_rem_vs v_wei_ix_list+2,v_wei_iy_list+2,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+2] + v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+2] + v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag+2] + + ;v_cmpx_le_u32 1, v[v_wei_flag+0] + ;global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0:v_wei_os+1], s[s_p_wei:s_p_wei+1] + ;s_mov_b64 exec, -1 + ;v_cmpx_le_u32 1, v[v_wei_flag+1] + ;global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1:v_wei_os+2], s[s_p_wei:s_p_wei+1] + ;s_mov_b64 exec, -1 + ;v_cmpx_le_u32 1, v[v_wei_flag+2] + ;global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2:v_wei_os+3], s[s_p_wei:s_p_wei+1] + ;s_mov_b64 exec, -1 + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_gld_a, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_gld_a+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + +; v_lshlrev_b32 v[v_tmp], 2, v0 +; s_lshr_b32 s[s_tmp], s2, 7 +; v_mov_b32 v[v_tmp+1], s[s_tmp] +; v_lshl_or_b32 v[v_out_os], v[v_tmp+1], 7+2, v[v_tmp] +; v_lshlrev_b32 v[v_out_os], 2, v[v_out_os] +; +; v_mov_b32 v[v_tmp+0], v0 +; v_mov_b32 v[v_tmp+1], v[v_in_os] +; v_mov_b32 v[v_tmp+2], v[v_in_flag] +; v_mov_b32 v[v_tmp+3], 0xffffccee +; +; global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_tmp:v_tmp+3], s[s_p_out:s_p_out+1] +; +; s_endpgm + + ; v_cmpx_le_u32 1, v[v_in_flag] + ; global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + ; global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + ; s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + .mdiv_u32_rem_ss 
s_wi_diff_batch_m, s_hi_diff_batch_m, s_stride_m, s_magic_1, s_shift_m1, s_wo, s_tmp + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + s_sub_i32 s[s_tmp+4], s[s_x], 1 + s_sub_i32 s[s_tmp+5], s[s_y], 1 + s_mul_i32 s[s_tmp+4], s[s_dilation_w], s[s_tmp+4] + s_mul_i32 s[s_tmp+5], s[s_dilation_w], s[s_tmp+5] + s_mul_i32 s[s_wi_diff_batch_m], s[s_stride_w], s[s_wi_diff_batch_m] + s_mul_i32 s[s_hi_diff_batch_m], s[s_stride_h], s[s_hi_diff_batch_m] + s_sub_i32 s[s_wi_diff_batch_m], s[s_wi_diff_batch_m], s[s_tmp+4] + s_sub_i32 s[s_hi_diff_batch_m], s[s_hi_diff_batch_m], s[s_tmp+5] + + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_mul_i32 s[s_tmp], s[s_hi_diff_batch_m], s[s_wi] + s_add_u32 s[s_tmp+1], s[s_wi_diff_batch_m], s[s_tmp] + s_mul_i32 s[s_in_os_diff_batch_m], s[s_tmp+1], s[s_in_stride_wi] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + s_lshl_b32 s[s_in_os_diff_batch_m], s[s_in_os_diff_batch_m], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + + s_waitcnt vmcnt(2) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os], v[v_gld_b+0], v[v_gld_b+1], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os], v[v_gld_b+2], v[v_gld_b+3], offset0:17*2 offset1:17*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os], v[v_gld_b+4], v[v_gld_b+5], offset0:17*4 offset1:17*5 + ds_write2_b32 v[v_sst_b_os], v[v_gld_b+6], v[v_gld_b+7], offset0:17*6 offset1:17*7 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+2] + ds_write2_b32 v[v_sst_b_os], v[v_gld_b+8], v[v_gld_b+9], offset0:17*8 offset1:17*9 + ds_write2_b32 v[v_sst_b_os], v[v_gld_b+10], v[v_gld_b+11], offset0:17*10 offset1:17*11 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + s_cmp_gt_i32 s[s_kitr], 0 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end + +L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body: + ; accumulate im + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + + ; accumulate b + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi], s[s_tmp], v[v_in_iwi] + v_add_nc_u32 v[v_in_os], s[s_tmp+1], v[v_in_os] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi], s[s_dilation_h], v[v_in_ihi] +igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] +igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_end_1: + + s_waitcnt vmcnt(0) + v_mov_b32 v[v_a + 0], v[v_gld_a + 0] + v_mov_b32 v[v_a + 1], v[v_gld_a + 1] + v_mov_b32 v[v_a + 2], v[v_gld_a + 
2] + v_mov_b32 v[v_a + 3], v[v_gld_a + 3] + v_mov_b32 v[v_a + 4], v[v_gld_a + 4] + v_mov_b32 v[v_a + 5], v[v_gld_a + 5] + v_mov_b32 v[v_a + 6], v[v_gld_a + 6] + v_mov_b32 v[v_a + 7], v[v_gld_a + 7] + .v_clear_nc v_gld_a, 8 + ; v_cmpx_le_u32 1, v[v_in_flag] + ; global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + ; global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + ; s_mov_b64 exec, -1 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 + + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + 
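+    ; s_waitcnt lgkmcnt(4) blocks until at most four LDS/scalar-memory operations remain
+    ; outstanding; the loop is arranged so the v_b data consumed by the following
+    ; .fma_1x8_fp16 is resident by that point.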
s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body + +L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end: + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + s_waitcnt vmcnt(0) + + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + + v_add_nc_i32 v[v_in_ihi], s[s_hi_diff_batch_m], v[v_in_ihi] + v_add_nc_u32 v[v_in_iwi], s[s_wi_diff_batch_m], v[v_in_iwi] + + .v_clear_nc v_gld_a, 4 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + + v_add_nc_u32 v[v_in_os], s[s_in_os_diff_batch_m], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + + .v_clear_nc v_gld_a+4, 4 + + ; v_cmpx_le_u32 1, v[v_in_flag] + ; global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + ; global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + ; s_mov_b64 exec, -1 +L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: + + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ; ds_read_b128 
v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 + ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 + ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + + s_waitcnt lgkmcnt(2) + .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + ; v_cmpx_le_u32 1, v[v_out_flag] + ; global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + ; global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 + ; s_mov_b64 exec, -1 + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_end + ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body +L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_end: + s_endpgm + +; LDS: (16+1) * (3 * 4) * 16 * 4 = 13056 +; k pad r3 e:4 c dword +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x16x16_r3 + .amdhsa_group_segment_fixed_size 13056 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 72 + .amdhsa_next_free_sgpr 54 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel From 68134980eca825ba5e67fdcda0901cef00a386f2 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 11 Mar 2021 19:27:03 +0800 Subject: [PATCH 30/40] pretty asm --- .../igemm_fwd_btm_nhwc_fp16_128x016.asm | 26 
+++++-------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm index 9eb0104c..7276005e 100644 --- a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm @@ -289,7 +289,7 @@ L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_body: ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 - ; accumulate b + ;--- start move slice window s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] @@ -304,7 +304,7 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_x_end_1: v_cndmask_b32 v[v_in_flag], 0, 1 v_cmp_gt_u32 s[s_hi], v[v_in_ihi] v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] -igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: + ;--- end move slice window s_waitcnt vmcnt(0) v_mov_b32 v[v_a + 0], v[v_gld_a + 0] @@ -320,15 +320,11 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 + s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 - - ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 - - - s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 @@ -339,8 +335,6 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 s_waitcnt lgkmcnt(4) - - .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 @@ -350,8 +344,6 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 s_waitcnt lgkmcnt(4) - - .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 @@ -360,7 +352,6 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 - s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 @@ -370,7 +361,6 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 - s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 @@ -390,9 +380,7 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_end_1: ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 s_waitcnt lgkmcnt(4) - - v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] - + 
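+    ; advance the LDS read base past the 17-dword padded rows consumed this iteration;
+    ; the ds_read_b128 offsets above (17*4*0 .. 17*4*7) are all relative to this base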
v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 @@ -426,12 +414,11 @@ L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end: s_cmp_gt_i32 s[s_batch_m], 0 s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next + ; --- start move slice for batch m ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w - ; we will update v_in_os below, so use this as v_tmp .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os - v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] .v_clear_nc v_gld_a, 4 v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] @@ -453,7 +440,7 @@ L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end: global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next: - + ; --- end move slice for batch m s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 @@ -548,7 +535,6 @@ L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next: v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] - v_cmpx_le_u32 1, v[v_out_flag] global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 From 6b05999861e4ee2ae7e38d972c0ac8e63e89664c Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 11 Mar 2021 19:45:51 +0800 Subject: [PATCH 31/40] tiny fix --- test/inference/build.sh | 2 +- test/inference/test_inference.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inference/build.sh b/test/inference/build.sh index ea241fb5..7506c183 100644 --- a/test/inference/build.sh +++ b/test/inference/build.sh @@ -4,7 +4,7 @@ ARCH=gfx1030 rm -rf out mkdir out -/opt/rocm/hip/bin/hipcc -Idriver -std=c++14 -lpthread test/inference/test_inference.cpp -o out/test_inference.exe || exit 1 +/opt/rocm/hip/bin/hipcc --amdgpu-target=$ARCH -Idriver -std=c++14 -lpthread test/inference/test_inference.cpp -o out/test_inference.exe || exit 1 /opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/ test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm -o out/igemm_fwd_btm_nhwc_fp16.hsaco || exit 1 /opt/rocm/hip/bin/hipcc -x hip --cuda-gpu-arch=$ARCH --cuda-device-only -c -O3 driver/gpu_naive_conv/naive_conv.cpp -o out/naive_conv.hsaco diff --git a/test/inference/test_inference.cpp b/test/inference/test_inference.cpp index 0ee64767..a6bf902e 100644 --- a/test/inference/test_inference.cpp +++ b/test/inference/test_inference.cpp @@ -450,7 +450,7 @@ int main(int argc, char **argv){ //gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 0.0, 1.0); //gen_rand_vector(host_weight, static_cast(k) * c * y * x, -0.5, 0.5); gen_rand_vector(host_input, static_cast(n) * c * hi * wi, -5, 5); - gen_rand_vector(host_weight, static_cast(k) * c * y * x, -2, 2); + gen_rand_vector(host_weight, static_cast(k) * c * y * x, -5, 5); //gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 1, 1); //gen_rand_vector(host_weight, static_cast(k) * c * y * x, 1, 1); From f533622c09f57ffe65f0f893dfdd574334bfbe20 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 11 Mar 2021 23:25:00 +0800 Subject: [PATCH 
32/40] 256x16 tile size --- test/inference/igemm_fwd_btm_nhwc.h | 2 +- .../kernel/igemm_fwd_btm_nhwc_fp16.asm | 37 ++ .../igemm_fwd_btm_nhwc_fp16_256x016.asm | 515 +++++++++++------- 3 files changed, 367 insertions(+), 187 deletions(-) diff --git a/test/inference/igemm_fwd_btm_nhwc.h b/test/inference/igemm_fwd_btm_nhwc.h index cc96ae38..aa217e0a 100644 --- a/test/inference/igemm_fwd_btm_nhwc.h +++ b/test/inference/igemm_fwd_btm_nhwc.h @@ -121,7 +121,7 @@ typedef struct { igemm_fwd_btm_kernel_info_t igemm_fwd_btm_kernel_list [] = { {"igemm_fwd_btm_nhwc_fp16_128x16x16_r3", 128, 16, 16, 128, 3, 4}, - // {"igemm_fwd_btm_nhwc_fp16_256x16x16_r3", 256, 16, 16, 128, 3, 4} + {"igemm_fwd_btm_nhwc_fp16_256x16x16_r3", 256, 16, 16, 128, 3, 4} }; class igemm_fwd_btm_t { diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm index cfa97782..d10f7dd2 100644 --- a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm @@ -112,5 +112,42 @@ amdhsa.kernels: - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} - { .name: shift_pack_0, .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} - { .name: __pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x16x16_r3 + .symbol: igemm_fwd_btm_nhwc_fp16_256x16x16_r3.kd + .sgpr_count: 60 + .vgpr_count: 112 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, 
.value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: __pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} ... .end_amdgpu_metadata diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm index 91f5da00..b24f876b 100644 --- a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm @@ -53,28 +53,22 @@ .set s_shift_pack_0, 36 .set s_shift_m0, 37 .set s_shift_m1, s_shift_pack_0 - - .set s_in_stride_wi, 12 .set s_in_stride_n, 13 .set s_wei_stride_k, 14 .set s_out_stride_wo, 15 .set s_out_stride_n, 38 - .set s_in_diff_hi, 39 .set s_in_diff_wi, 40 .set s_dilation_w_x, 41 .set s_move_slice_k_ix, 42 -.set s_hi_diff_batch_m, 43 -.set s_wi_diff_batch_m, 44 -.set s_in_os_diff_batch_m, 45 - .set s_kitr, 1 -.set s_wei_offset, 46 +.set s_wei_offset, 43 .set s_out_stride, s_wei_offset -.set s_vmcnt_out, s_magic_1 -.set s_br, 47 +.set s_sld_b_stride, 44 +.set s_br, 45 +.set s_ib_stride, 46 .set s_tmp, 48 .set s_end, 54 @@ -86,9 +80,9 @@ .set v_c_buf, v_c .set v_sld_b_os, 32 .set v_a, 33 -.set v_sst_b_os, 25 -.set v_b, 26 -.set v_gld_a, 58 +.set v_ib, 49 +.set v_b, 50 +.set v_gld_a, 82 .set v_gld_b, v_b .set v_wei_iy_list, v_b+12 .set v_wei_ix_list, v_b+15 @@ -99,14 +93,15 @@ .set v_wei_ic, v_a+1 .set v_wei_ie, v_a+2 .set v_wei_flag_ik, v_a+3 -.set v_ib, v_a+4 -.set v_in_os, 66 -.set v_in_ihi, 67 -.set v_in_iwi, 68 -.set v_in_flag, 69 -.set v_out_os, 70 -.set v_out_flag, 71 -.set v_end, 72 +.set v_sst_b_os, v_a+4 +.set v_in_os, 98 +.set v_in_ihi, 100 +.set v_in_iwi, 102 +.set v_in_flag, 104 +.set v_out_os, 106 +.set v_out_flag, 108 +.set v_tid, 100 +.set v_end, 112 ; short wide igemv .text @@ -120,28 +115,33 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 4 for yx, 2 for c v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], 17*4 * 4 ; 17dword per row, 4 row + v_and_b32 v[v_tmp+5], 7, v0 s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 v_and_b32 v[v_wei_ic], 1, v0 s_lshl_b32 s[s_block_in], s[s_block_in], 1 v_lshrrev_b32 v[v_tmp+4], 1, v0 - s_lshl_b32 s[s_block_ib], s[s_block_ib], 8 ; 256 + s_lshl_b32 s[s_block_ib], s[s_block_ib], 8 ; 256 half v_mov_b32 v[v_ib], v0 - v_mov_b32 v[v_sld_b_os], 0 ; load - - v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to 17 - v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword - v_and_b32 v[v_wei_ie], 3, v[v_tmp+4] ; yx + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to 17 + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 3, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad s_waitcnt lgkmcnt(0) s_mul_i32 s[s_tmp], s[s_x], s[s_c] - s_bfe_u32 
s[s_shift_m0], s[s_shift_pack_0], 0x00000008 ; offset:0, width:8 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] - s_lshl_b32 s[s_wei_offset], s[s_c], 2+1 ; 4x s_c, half + s_lshl_b32 s[s_wei_offset], s[s_c], 2+1 ; 4x s_c, half s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] @@ -174,15 +174,18 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+2] v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag+2] - ;v_cmpx_le_u32 1, v[v_wei_flag+0] - ;global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0:v_wei_os+1], s[s_p_wei:s_p_wei+1] - ;s_mov_b64 exec, -1 - ;v_cmpx_le_u32 1, v[v_wei_flag+1] - ;global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1:v_wei_os+2], s[s_p_wei:s_p_wei+1] - ;s_mov_b64 exec, -1 - ;v_cmpx_le_u32 1, v[v_wei_flag+2] - ;global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2:v_wei_os+3], s[s_p_wei:s_p_wei+1] - ;s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0:v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1:v_wei_os+2], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+2] + global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2:v_wei_os+3], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 32*17*4 ; stride for wei sst offset. 8 thread for k, each thread store 4 c, hence 8*4=32 + ; 17 dword per row ; calculate in offset s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] @@ -194,9 +197,11 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] .v_clear_nc v_gld_a, 4 @@ -204,6 +209,9 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] .v_clear_nc v_gld_a+4, 4 v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + v_add_nc_u32 v[v_sst_b_os+2], s[s_tmp+5], v[v_sst_b_os+1] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] v_cmp_gt_u32 s[s_hi], v[v_in_ihi] @@ -213,31 +221,38 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_gld_a+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_gld_a+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + 
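+    ; v_in_flag+1 becomes 1 only when this row's v_in_ihi+1 / v_in_iwi+1 fall inside [0,hi) / [0,wi)
+    ; (the unsigned compare also rejects the negative, padded coordinates); it predicates the
+    ; second pair of global_load_dwordx4 below via v_cmpx_le_u32 / exec restore, with the
+    ; cleared v_gld_a registers acting as zero padding for the skipped lanes.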
v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] s_mov_b32 s[s_move_slice_k_ix], 0 -; v_lshlrev_b32 v[v_tmp], 2, v0 -; s_lshr_b32 s[s_tmp], s2, 7 -; v_mov_b32 v[v_tmp+1], s[s_tmp] -; v_lshl_or_b32 v[v_out_os], v[v_tmp+1], 7+2, v[v_tmp] -; v_lshlrev_b32 v[v_out_os], 2, v[v_out_os] -; -; v_mov_b32 v[v_tmp+0], v0 -; v_mov_b32 v[v_tmp+1], v[v_in_os] -; v_mov_b32 v[v_tmp+2], v[v_in_flag] -; v_mov_b32 v[v_tmp+3], 0xffffccee -; -; global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_tmp:v_tmp+3], s[s_p_out:s_p_out+1] -; -; s_endpgm - - ; v_cmpx_le_u32 1, v[v_in_flag] - ; global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - ; global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 - ; s_mov_b64 exec, -1 - - s_mul_i32 s[s_br], s[s_wo], s[s_ho] s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] @@ -245,7 +260,7 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 - ; calculate diffs + ; calculate diffs, for y, x s_sub_i32 s[s_tmp+3], s[s_x], 1 s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] @@ -254,27 +269,12 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 - .mdiv_u32_rem_ss s_wi_diff_batch_m, s_hi_diff_batch_m, s_stride_m, s_magic_1, s_shift_m1, s_wo, s_tmp - ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h - ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w - s_sub_i32 s[s_tmp+4], s[s_x], 1 - s_sub_i32 s[s_tmp+5], s[s_y], 1 - s_mul_i32 s[s_tmp+4], s[s_dilation_w], s[s_tmp+4] - s_mul_i32 s[s_tmp+5], s[s_dilation_w], s[s_tmp+5] - s_mul_i32 s[s_wi_diff_batch_m], s[s_stride_w], s[s_wi_diff_batch_m] - s_mul_i32 s[s_hi_diff_batch_m], s[s_stride_h], s[s_hi_diff_batch_m] - s_sub_i32 s[s_wi_diff_batch_m], s[s_wi_diff_batch_m], s[s_tmp+4] - s_sub_i32 s[s_hi_diff_batch_m], s[s_hi_diff_batch_m], s[s_tmp+5] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] - s_mul_i32 s[s_tmp], s[s_hi_diff_batch_m], s[s_wi] - s_add_u32 s[s_tmp+1], s[s_wi_diff_batch_m], s[s_tmp] - s_mul_i32 s[s_in_os_diff_batch_m], s[s_tmp+1], s[s_in_stride_wi] - s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 - s_lshl_b32 s[s_in_os_diff_batch_m], s[s_in_os_diff_batch_m], 1 ; output offset v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] @@ -282,55 +282,69 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] v_cndmask_b32 v[v_out_flag], 0, 1 + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_mov_b32 s[s_sld_b_stride], 17*8*4 + s_waitcnt 
vmcnt(2) v_cmpx_le_u32 1, v[v_wei_flag+0] - ds_write2_b32 v[v_sst_b_os], v[v_gld_b+0], v[v_gld_b+1], offset0:17*0 offset1:17*1 - ds_write2_b32 v[v_sst_b_os], v[v_gld_b+2], v[v_gld_b+3], offset0:17*2 offset1:17*3 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:17*2 offset1:17*3 s_mov_b64 exec, -1 v_cmpx_le_u32 1, v[v_wei_flag+1] - ds_write2_b32 v[v_sst_b_os], v[v_gld_b+4], v[v_gld_b+5], offset0:17*4 offset1:17*5 - ds_write2_b32 v[v_sst_b_os], v[v_gld_b+6], v[v_gld_b+7], offset0:17*6 offset1:17*7 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:17*2 offset1:17*3 s_mov_b64 exec, -1 v_cmpx_le_u32 1, v[v_wei_flag+2] - ds_write2_b32 v[v_sst_b_os], v[v_gld_b+8], v[v_gld_b+9], offset0:17*8 offset1:17*9 - ds_write2_b32 v[v_sst_b_os], v[v_gld_b+10], v[v_gld_b+11], offset0:17*10 offset1:17*11 + ds_write2_b32 v[v_sst_b_os+2], v[v_gld_b+8], v[v_gld_b+9], offset0:17*0 offset1:17*1 + ds_write2_b32 v[v_sst_b_os+2], v[v_gld_b+10], v[v_gld_b+11], offset0:17*2 offset1:17*3 s_mov_b64 exec, -1 - .v_clear_nc v_c, 16 + .v_clear_nc v_c, 32 s_waitcnt lgkmcnt(0) s_barrier - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 s_cmp_gt_i32 s[s_kitr], 0 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body: ; accumulate im - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 - ; accumulate b + ;--- start move slice window s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] - v_add_nc_u32 v[v_in_iwi], s[s_tmp], v[v_in_iwi] - v_add_nc_u32 v[v_in_os], s[s_tmp+1], v[v_in_os] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_x_end_1 s_mov_b32 s[s_move_slice_k_ix], 0 - v_add_nc_i32 v[v_in_ihi], s[s_dilation_h], v[v_in_ihi] + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_x_end_1: - v_cmp_gt_u32 s[s_wi], v[v_in_iwi] - v_cndmask_b32 v[v_in_flag], 0, 1 - v_cmp_gt_u32 s[s_hi], v[v_in_ihi] - v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] -igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + 
v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window s_waitcnt vmcnt(0) v_mov_b32 v[v_a + 0], v[v_gld_a + 0] @@ -341,179 +355,268 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_end_1: v_mov_b32 v[v_a + 5], v[v_gld_a + 5] v_mov_b32 v[v_a + 6], v[v_gld_a + 6] v_mov_b32 v[v_a + 7], v[v_gld_a + 7] - .v_clear_nc v_gld_a, 8 - ; v_cmpx_le_u32 1, v[v_in_flag] - ; global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - ; global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 - ; s_mov_b64 exec, -1 + v_mov_b32 v[v_a + 8], v[v_gld_a + 8] + v_mov_b32 v[v_a + 9], v[v_gld_a + 9] + v_mov_b32 v[v_a +10], v[v_gld_a +10] + v_mov_b32 v[v_a +11], v[v_gld_a +11] + v_mov_b32 v[v_a +12], v[v_gld_a +12] + v_mov_b32 v[v_a +13], v[v_gld_a +13] + v_mov_b32 v[v_a +14], v[v_gld_a +14] + v_mov_b32 v[v_a +15], v[v_gld_a +15] + .v_clear_nc v_gld_a, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 - - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + .fma_1x8_fp16 v_c+16, v_a + 8, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + .fma_1x8_fp16 v_c+24, v_a + 8, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + .fma_1x8_fp16 v_c+16, v_a + 9, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + .fma_1x8_fp16 v_c+24, v_a + 9, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + .fma_1x8_fp16 v_c+16, v_a +10, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 - ; ds_read_b128 v[v_b+ 
0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +10, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 + .fma_1x8_fp16 v_c+16, v_a +11, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 3, v_b +24 - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +11, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + .fma_1x8_fp16 v_c+16, v_a +12, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +12, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + .fma_1x8_fp16 v_c+16, v_a +13, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +13, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + .fma_1x8_fp16 v_c+16, v_a +14, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 s_waitcnt lgkmcnt(4) + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +14, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 - ; ds_read_b128 
v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + .fma_1x8_fp16 v_c+16, v_a +15, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 s_waitcnt lgkmcnt(4) - .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 + .fma_1x8_fp16 v_c+24, v_a +15, v_b +24 s_sub_i32 s[s_kitr], s[s_kitr], 16 s_cmp_gt_i32 s[s_kitr], 0 s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end: - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*1 + 4*4 s_waitcnt vmcnt(0) + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_mov_b32 v[v_a + 0], v[v_gld_a + 0] + v_mov_b32 v[v_a + 1], v[v_gld_a + 1] + v_mov_b32 v[v_a + 2], v[v_gld_a + 2] + v_mov_b32 v[v_a + 3], v[v_gld_a + 3] + v_mov_b32 v[v_a + 4], v[v_gld_a + 4] + v_mov_b32 v[v_a + 5], v[v_gld_a + 5] + v_mov_b32 v[v_a + 6], v[v_gld_a + 6] + v_mov_b32 v[v_a + 7], v[v_gld_a + 7] + v_mov_b32 v[v_a + 8], v[v_gld_a + 8] + v_mov_b32 v[v_a + 9], v[v_gld_a + 9] + v_mov_b32 v[v_a +10], v[v_gld_a +10] + v_mov_b32 v[v_a +11], v[v_gld_a +11] + v_mov_b32 v[v_a +12], v[v_gld_a +12] + v_mov_b32 v[v_a +13], v[v_gld_a +13] + v_mov_b32 v[v_a +14], v[v_gld_a +14] + v_mov_b32 v[v_a +15], v[v_gld_a +15] s_cmp_gt_i32 s[s_batch_m], 0 s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next + ; --- start move slice for batch m ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w - - v_add_nc_i32 v[v_in_ihi], s[s_hi_diff_batch_m], v[v_in_ihi] - v_add_nc_u32 v[v_in_iwi], s[s_wi_diff_batch_m], v[v_in_iwi] - + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] .v_clear_nc v_gld_a, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_gld_a+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] v_cmp_gt_u32 s[s_hi], v[v_in_ihi] v_cndmask_b32 v[v_in_flag], 0, 1 - - v_add_nc_u32 v[v_in_os], s[s_in_os_diff_batch_m], v[v_in_os] + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] v_cmp_gt_u32 s[s_wi], v[v_in_iwi] - v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_gld_a+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_gld_a+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 - .v_clear_nc v_gld_a+4, 4 + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 
v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] - ; v_cmpx_le_u32 1, v[v_in_flag] - ; global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - ; global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 - ; s_mov_b64 exec, -1 -L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: + s_mov_b32 s[s_move_slice_k_ix], 0 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 +L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: + ; --- end move slice for batch m s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 0, v_b + 0 - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 + .fma_1x8_fp16 v_c+16, v_a + 8, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*1 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 0, v_b + 8 - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 + .fma_1x8_fp16 v_c+24, v_a + 8, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*2 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*2 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 1, v_b +16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 + .fma_1x8_fp16 v_c+16, v_a + 9, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*2 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*2 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 1, v_b +24 - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 + .fma_1x8_fp16 v_c+24, v_a + 9, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*3 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*3 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 2, v_b + 0 - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 + .fma_1x8_fp16 v_c+16, v_a +10, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*3 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*3 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 2, v_b + 8 - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +10, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*4 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*4 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 3, v_b +16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 + .fma_1x8_fp16 v_c+16, v_a +11, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*4 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*4 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 
v_c+ 8, v_a + 3, v_b +24 - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +11, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*5 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*5 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 4, v_b + 0 - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 + .fma_1x8_fp16 v_c+16, v_a +12, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*5 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*5 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 4, v_b + 8 - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +12, v_b + 8 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*6 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*6 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 5, v_b +16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 + .fma_1x8_fp16 v_c+16, v_a +13, v_b +16 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*6 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*6 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 5, v_b +24 - ; ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 - ; ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 + .fma_1x8_fp16 v_c+24, v_a +13, v_b +24 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:17*4*7 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:17*4*7 + 4*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 0, v_a + 6, v_b + 0 - ; ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 - ; ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 + .fma_1x8_fp16 v_c+16, v_a +14, v_b + 0 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:17*4*7 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:17*4*7 +12*4 s_waitcnt lgkmcnt(4) .fma_1x8_fp16 v_c+ 8, v_a + 6, v_b + 8 + .fma_1x8_fp16 v_c+24, v_a +14, v_b + 8 + v_mov_b32 v[v_sld_b_os], 0 ; reset to start s_waitcnt lgkmcnt(2) .fma_1x8_fp16 v_c+ 0, v_a + 7, v_b +16 @@ -525,6 +628,15 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + .fma_1x8_fp16 v_c+16, v_a +15, v_b +16 + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] s_waitcnt lgkmcnt(0) .fma_1x8_fp16 v_c+ 8, v_a + 7, v_b +24 v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] @@ -535,6 +647,15 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: v_cvt_f16_f32 v[v_c +13], v[v_c +13] v_cvt_f16_f32 v[v_c +14], v[v_c +14] v_cvt_f16_f32 v[v_c +15], v[v_c +15] + .fma_1x8_fp16 v_c+24, v_a +15, v_b +24 + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + 
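+    ; the v_c accumulators are fp32; they are converted to fp16 here and packed in pairs into
+    ; v_c_buf so each output row can be written with two global_store_dwordx4 below.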
v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] @@ -546,24 +667,46 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] - ; v_cmpx_le_u32 1, v[v_out_flag] - ; global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] - ; global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 - ; s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 + s_mov_b64 exec, -1 + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1:v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os+1:v_out_os+2], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1], offset:16 + s_mov_b64 exec, -1 - s_sub_i32 s[s_batch_m], s[s_batch_m], 1 s_cmp_le_i32 s[s_batch_m], 0 s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_end - ; ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 - ; ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:17*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:17*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 - ; ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:17*4*0 + 8*4 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] v_cndmask_b32 v[v_out_flag], 0, 1 s_cmp_gt_i32 s[s_kitr], 0 - ; ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:17*4*0 +12*4 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end s_branch L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_body L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_end: @@ -580,7 +723,7 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_end: .amdhsa_system_sgpr_workgroup_id_y 1 .amdhsa_system_sgpr_workgroup_id_z 1 .amdhsa_system_vgpr_workitem_id 0 - .amdhsa_next_free_vgpr 72 + .amdhsa_next_free_vgpr 112 .amdhsa_next_free_sgpr 54 .amdhsa_ieee_mode 0 .amdhsa_dx10_clamp 0 From de7c670b0aecbaa9c7dfa64e4b3f1e52cd99d6b5 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Fri, 12 Mar 2021 03:39:03 +0000 Subject: [PATCH 33/40] proper global inst --- .../igemm_fwd_btm_nhwc_fp16_128x016.asm | 22 +++++------ .../igemm_fwd_btm_nhwc_fp16_256x016.asm | 38 +++++++++---------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm 
index 7276005e..136a8afd 100644 --- a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm @@ -173,13 +173,13 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3: v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag+2] v_cmpx_le_u32 1, v[v_wei_flag+0] - global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0:v_wei_os+1], s[s_p_wei:s_p_wei+1] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] s_mov_b64 exec, -1 v_cmpx_le_u32 1, v[v_wei_flag+1] - global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1:v_wei_os+2], s[s_p_wei:s_p_wei+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] s_mov_b64 exec, -1 v_cmpx_le_u32 1, v[v_wei_flag+2] - global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2:v_wei_os+3], s[s_p_wei:s_p_wei+1] + global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2], s[s_p_wei:s_p_wei+1] s_mov_b64 exec, -1 s_mov_b32 s[s_tmp+5], 32*17*4 ; stride for wei sst offset. 8 thread for k, each thread store 4 c, hence 8*4=32 @@ -217,8 +217,8 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3: v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] v_cmpx_le_u32 1, v[v_in_flag] - global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 s_mul_i32 s[s_br], s[s_wo], s[s_ho] @@ -317,8 +317,8 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_acc_yx_x_end_1: v_mov_b32 v[v_a + 7], v[v_gld_a + 7] .v_clear_nc v_gld_a, 8 v_cmpx_le_u32 1, v[v_in_flag] - global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 s_waitcnt lgkmcnt(4) @@ -436,8 +436,8 @@ L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end: s_mov_b32 s[s_move_slice_k_ix], 0 v_cmpx_le_u32 1, v[v_in_flag] - global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next: ; --- end move slice for batch m @@ -536,8 +536,8 @@ L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_fma_end_not_load_next: v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] v_cmpx_le_u32 1, v[v_out_flag] - global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] - global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 s_mov_b64 exec, -1 s_cmp_le_i32 s[s_batch_m], 0 diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm index b24f876b..60092b31 100644 --- 
a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm +++ b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm @@ -175,13 +175,13 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: v_cndmask_b32 v[v_wei_flag+2], 0, v[v_wei_flag+2] v_cmpx_le_u32 1, v[v_wei_flag+0] - global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0:v_wei_os+1], s[s_p_wei:s_p_wei+1] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] s_mov_b64 exec, -1 v_cmpx_le_u32 1, v[v_wei_flag+1] - global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1:v_wei_os+2], s[s_p_wei:s_p_wei+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] s_mov_b64 exec, -1 v_cmpx_le_u32 1, v[v_wei_flag+2] - global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2:v_wei_os+3], s[s_p_wei:s_p_wei+1] + global_load_dwordx4 v[v_gld_b+8:v_gld_b+11], v[v_wei_os+2], s[s_p_wei:s_p_wei+1] s_mov_b64 exec, -1 s_mov_b32 s[s_tmp+5], 32*17*4 ; stride for wei sst offset. 8 thread for k, each thread store 4 c, hence 8*4=32 @@ -229,8 +229,8 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] v_cmpx_le_u32 1, v[v_in_flag] - global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] @@ -242,8 +242,8 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] v_cmpx_le_u32 1, v[v_in_flag+1] - global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 @@ -365,12 +365,12 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_acc_yx_x_end_1: v_mov_b32 v[v_a +15], v[v_gld_a +15] .v_clear_nc v_gld_a, 16 v_cmpx_le_u32 1, v[v_in_flag+0] - global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 v_cmpx_le_u32 1, v[v_in_flag+1] - global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 s_waitcnt lgkmcnt(4) @@ -521,8 +521,8 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end: v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] v_cmpx_le_u32 1, v[v_in_flag] - global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], v[v_in_os:v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+0:v_gld_a+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+4:v_gld_a+7], 
v[v_in_os], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] @@ -536,8 +536,8 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end: s_mov_b32 s[s_move_slice_k_ix], 0 v_cmpx_le_u32 1, v[v_in_flag+1] - global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] - global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1:v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + global_load_dwordx4 v[v_gld_a+ 8:v_gld_a+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_gld_a+12:v_gld_a+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 s_mov_b64 exec, -1 L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: ; --- end move slice for batch m @@ -668,8 +668,8 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] v_cmpx_le_u32 1, v[v_out_flag] - global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] - global_store_dwordx4 v[v_out_os:v_out_os+1], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os], v[v_c_buf+4:v_c_buf+7], s[s_p_out:s_p_out+1], offset:16 s_mov_b64 exec, -1 v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] @@ -683,8 +683,8 @@ L_igemm_fwd_btm_nhwc_fp16_256x16x16_r3_fma_end_not_load_next: v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] v_cmpx_le_u32 1, v[v_out_flag+1] - global_store_dwordx4 v[v_out_os+1:v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] - global_store_dwordx4 v[v_out_os+1:v_out_os+2], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1], offset:16 + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1], offset:16 s_mov_b64 exec, -1 s_cmp_le_i32 s[s_batch_m], 0 From 798c0a4828d9c5bf025ffebde3ac9121d66dde52 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Thu, 18 Mar 2021 16:35:49 +0800 Subject: [PATCH 34/40] remove limitation --- igemm/algo/mfma_main_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/igemm/algo/mfma_main_loop.py b/igemm/algo/mfma_main_loop.py index 66b5e34e..57b2edbc 100644 --- a/igemm/algo/mfma_main_loop.py +++ b/igemm/algo/mfma_main_loop.py @@ -166,7 +166,7 @@ def emit_single_pass_through(self): p_interleave_gld = [ctrl.pass_through_a_interleave_gld, ctrl.pass_through_b_interleave_gld][p_idx] - assert wave_repeat_q == 2, "currently the side need LDS must have repeat 2, following limitation seems have BUG" + # assert wave_repeat_q == 2, "currently the side need LDS must have repeat 2, following limitation seems have BUG" v_pack_p = [ctrl.pass_through_a_v_pack, ctrl.pass_through_b_v_pack][p_idx] v_pack_q = [ctrl.pass_through_a_v_pack, ctrl.pass_through_b_v_pack][q_idx] From 4d5787676dc5d5cb36ed9b7ea0df74396fc57530 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 22 Mar 2021 22:04:15 +0800 Subject: [PATCH 35/40] add create_base_args --- driver/args.h | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/driver/args.h b/driver/args.h index 83a16d7c..4cd5ba70 100644 --- a/driver/args.h +++ b/driver/args.h @@ -158,11 +158,9 @@ class args_t { }; static inline args_t create_conv_args(int argc, char *argv[]) { - const std::string base("conv"); - const std::string base_fp16("convfp16"); - const std::string base_bf16("convbf16"); - if (argc >= 2 && argv[1] != base && argv[1] != base_fp16 && argv[1] != 
base_bf16) { - printf("not proper base arg name\n"); + const std::string base = create_base_args(argc, argv); + if (argc >= 2 && argv[1] != base) { + printf("not proper base arg name"); exit(1); } From d7567068e28a3762ae34ea95cedc7bf525309f31 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 22 Mar 2021 21:52:09 +0800 Subject: [PATCH 36/40] Nhwc fwd gemmk split (#88) * add raw code to pta case * fix bug for non-pta case * add some tunables * update all configs * remove redundant print * refactor driver Co-authored-by: shaojiewang --- config/igemm_fwd_gtc_gfx908_nhwc.config | 215 ++++++++++++- driver/args.h | 20 ++ driver/conv_driver.cpp | 401 ++++++++++++++---------- driver/igemm_bwd_gtc_driver.h | 234 +++----------- driver/igemm_fwd_gtc_driver.h | 212 +++++-------- driver/igemm_gtc_base.h | 225 +++++++++++++ driver/igemm_wrw_gtc_driver.h | 189 ++--------- igemm/algo/igemm_fwd_gtc_nhwc.py | 39 ++- retieve_perf_data.py | 77 +++++ script/gtc_conv_model.sh | 329 ++++++++++--------- 10 files changed, 1131 insertions(+), 810 deletions(-) create mode 100644 retieve_perf_data.py diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index 2cb3b0a5..41e3e81e 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -47,6 +47,51 @@ tensor_layout = 'nhwc' nxb = 0 nxe = 0 +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 2 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 +gemm_k_global_split = 1 + +#--------------------------- 256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 256 +gemm_n_per_block = 128 +gemm_k_per_block = 16 +wave_tile_m = 64 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +tensor_a_thread_lengths = [1, 4, 4, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0xK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 + #--------------------------- 256x128 [igemm_fwd_gtc] gemm_m_per_block = 256 @@ -91,27 +136,162 @@ tensor_layout = 'nhwc' nxb = 0 nxe = 0 -#--------------------------- 128x128 +#--------------------------- 128x256 [igemm_fwd_gtc] gemm_m_per_block = 128 -gemm_n_per_block = 128 +gemm_n_per_block = 256 gemm_k_per_block = 16 wave_tile_m = 32 wave_step_m = 1 wave_repeat_m = 2 -wave_tile_n = 32 +wave_tile_n = 64 wave_step_n = 1 wave_repeat_n = 2 -wave_tile_k = 2 +wave_tile_k = 1 tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 -tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 + +#--------------------------- 128x256 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 256 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 64 +wave_step_n = 1 
+wave_repeat_n = 2 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 0 + +#--------------------------- 128x256 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 256 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 64 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 + +#--------------------------- 128x256 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 256 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 2 +wave_tile_n = 64 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 1 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' nxb = 0 nxe = 0 +gemm_k_global_split = 1 + +#--------------------------- 128x128 +#[igemm_fwd_gtc] +#gemm_m_per_block = 128 +#gemm_n_per_block = 128 +#gemm_k_per_block = 16 +#wave_tile_m = 32 +#wave_step_m = 1 +#wave_repeat_m = 2 +#wave_tile_n = 32 +#wave_step_n = 1 +#wave_repeat_n = 2 +#wave_tile_k = 2 +#tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +#tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +#tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +#tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +#direction = "fwd" +#precision = "fp32" +#tensor_layout = 'nhwc' +#nxb = 0 +#nxe = 0 + +#--------------------------- 128x128 +#[igemm_fwd_gtc] +#gemm_m_per_block = 128 +#gemm_n_per_block = 128 +#gemm_k_per_block = 16 +#wave_tile_m = 32 +#wave_step_m = 1 +#wave_repeat_m = 2 +#wave_tile_n = 32 +#wave_step_n = 1 +#wave_repeat_n = 2 +#wave_tile_k = 2 +#tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +#tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +#tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +#tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +#direction = "fwd" +#precision = "fp32" +#tensor_layout = 'nhwc' +#nxb = 0 +#nxe = 1 + +#--------------------------- 128x128 +#[igemm_fwd_gtc] +#gemm_m_per_block = 128 +#gemm_n_per_block = 128 +#gemm_k_per_block = 16 +#wave_tile_m = 32 +#wave_step_m = 1 +#wave_repeat_m = 2 +#wave_tile_n = 32 +#wave_step_n = 1 +#wave_repeat_n = 2 +#wave_tile_k = 2 +#tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +#tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +#tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 +#tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +#direction = "fwd" +#precision = "fp32" +#tensor_layout = 'nhwc' +#nxb = 0 +#nxe = 0 +#gemm_k_global_split = 1 #--------------------------- 128x128 [igemm_fwd_gtc] @@ -134,8 +314,31 @@ precision = "fp32" tensor_layout = 'nhwc' nxb = 0 nxe = 1 +gemm_k_global_split = 1 - +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 1 
+wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_pass_through = 1 +tensor_a_thread_lengths = [1, 8, 1, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 2, 4, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 #--------------------------- 64x128 [igemm_fwd_gtc] diff --git a/driver/args.h b/driver/args.h index 4cd5ba70..6114fe97 100644 --- a/driver/args.h +++ b/driver/args.h @@ -157,6 +157,26 @@ class args_t { std::unordered_map input_map; }; +static inline std::string create_base_args(int argc, char *argv[]) { + if(argc < 2) + { + printf("Invalid Number of Input Arguments\n"); + exit(0); + } + + std::string arg = argv[1]; + + if(arg != "conv" && arg != "convfp16" && arg != "convint8" && arg != "--version") + { + printf("Invalid Base Input Argument\n"); + exit(0); + } + else if(arg == "-h" || arg == "--help" || arg == "-?") + exit(0); + else + return arg; +} + static inline args_t create_conv_args(int argc, char *argv[]) { const std::string base = create_base_args(argc, argv); if (argc >= 2 && argv[1] != base) { diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index fb489eb4..4f592ff3 100755 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -105,10 +105,11 @@ static int next_pow2(int n) { return n << 1; } typedef struct { - int return_code; - float duration_ms; - float gflops; - float efficiency; + int return_code {-1}; + int gks {0}; // this is to store the gks value after benchmarked + float duration_ms {FLT_MAX}; + float gflops {0}; + float efficiency {0}; std::string kernel_name; } result_t; @@ -122,7 +123,12 @@ typedef struct { } \ } while (0) -static inline double theoritical_fp32_gflops(double sclk_ghz, size_t cu, +#include "igemm_gtc_base.h" +#include "igemm_fwd_gtc_driver.h" +#include "igemm_bwd_gtc_driver.h" +#include "igemm_wrw_gtc_driver.h" + +static inline double theoritical_gflops(double sclk_ghz, size_t cu, size_t simd) { return 2 * sclk_ghz * cu * simd; } @@ -148,10 +154,63 @@ measured_fp32_conv_gflops(double time_ms, size_t n, size_t c, size_t hi, return flop / (time_ms * 1e6); } -#include "igemm_gtc_base.h" -#include "igemm_fwd_gtc_driver.h" -#include "igemm_bwd_gtc_driver.h" -#include "igemm_wrw_gtc_driver.h" + +static inline double get_theoritical_conv_flop(const args_t * conv_args) +{ + int hi = conv_args->get_int("in_h"); + int wi = conv_args->get_int("in_w"); + int n = conv_args->get_int("batchsize"); + int k = conv_args->get_int("out_channels"); + int c = conv_args->get_int("in_channels"); + + int stride_h = conv_args->get_int("conv_stride_h"); + int stride_w = conv_args->get_int("conv_stride_w"); + int dilation_h = conv_args->get_int("dilation_h"); + int dilation_w = conv_args->get_int("dilation_w"); + int pad_h = conv_args->get_int("pad_h"); + int pad_w = conv_args->get_int("pad_w"); + int y = conv_args->get_int("fil_h"); + int x = conv_args->get_int("fil_w"); + int ngroups = conv_args->get_int("group_count"); + int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + + return theoritical_fp32_conv_flop(n, c, hi, wi, k, y, x, stride_h, stride_w, + dilation_h, dilation_w, pad_h, pad_w, ngroups); +} + +static inline double get_theoritical_gpu_gflops(int sclk_mhz, driverDataType_t data_type) +{ + int num_cu; + int gcn_arch = 0; + int num_simd = 4 * 16; + hipDeviceProp_t 
dev_prop; + hipDevice_t dev; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + gcn_arch = dev_prop.gcnArch; + if(gcn_arch >= 1000) + num_cu *= 2; + + int fp_factor = 1; + if(data_type == driverHalf){ + if(gcn_arch == 908) + fp_factor = 4; // xdlops + else + fp_factor = 2; // dlops + } + // else if(data_type == driverInt8){ + // if(gcn_arch == 908) + // fp_factor = 4; + // } + + if(gcn_arch == 908){ + num_simd = 4 * 32 ; // 4x miSIMD, 32x mac unit + } + + return theoritical_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd * fp_factor); +} #ifndef ABS #define ABS(x) ((x) > 0 ? (x) : -1 * (x)) @@ -317,6 +376,17 @@ static inline double get_wrw_nrms() #endif } +static inline double get_nrms(std::string direction) +{ + if(direction == "fwd") + return get_fwd_nrms(); + if(direction == "bwd") + return get_bwd_nrms(); + if(direction == "wrw") + return get_wrw_nrms(); + assert(0); +} + void dump_arg(const args_t *arg) { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); @@ -341,16 +411,105 @@ void dump_arg(const args_t *arg) { pad_h, pad_w, ho, wo); } + +template +void launch_conv_driver(driver_t * driver, const args_t *conv_args, const std::vector & tunables, std::string direction, + void* device_input, void* device_weight, void* device_output, + pre_func_t && pre_func, post_func_t && post_func) +{ + int sclk_mhz = env_get_int("IGEMM_SCLK_MHZ", SCLK_MHZ); + std::string run_only_kernel = env_get_str("IGEMM_RUN_ONLY_KERNEL", IGEMM_RUN_ONLY_KERNEL_DEFAULT); + int log_fastest_config = env_get_int("IGEMM_LOG_FASTEST_CONFIG", 0); + + double theo_conv_flop = get_theoritical_conv_flop(conv_args); + double theo_gpu_gflops = get_theoritical_gpu_gflops(sclk_mhz, driver->data_type); + + auto launch = [&](const igemm_gtc_tunable_t * tunable, int index) -> result_t { + if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT){ + if(run_only_kernel != driver->get_kernel_name(tunable)){ + return result_t{}; + } + } + + printf("[%s:%2d] %s", direction.c_str(), index, driver->get_kernel_name(tunable).c_str()); + fflush(stdout); + + pre_func(); + + result_t result = driver->run(conv_args, tunable, device_input, device_weight, device_output); + + std::string gks_string = ""; + if(tunable->gemm_k_global_split){ + gks_string = "[" + std::to_string(result.gks) + "]"; + } + printf("%s, ", gks_string.c_str()); + + if (result.return_code != 0){ + printf("not applicatble\n"); + return result_t{}; + } + + double gflops = theo_conv_flop / (result.duration_ms * 1e6); + printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms, + gflops / 1000 , (gflops / theo_gpu_gflops) * 100); + + post_func(); + + printf("\n"); + result.gflops = gflops; + result.efficiency = (gflops / theo_gpu_gflops) * 100; + return result; + }; + + if(driver->driver_mode == driver_mode_normal){ + result_t fastest_result_fwd; + fastest_result_fwd.duration_ms = FLT_MAX; + int fastest_id = -1; + + for(int i=0; idriver_mode == driver_mode_heuristic){ + igemm_gtc_tunable_t selected_tunable = driver->heuristic_select_kernel(conv_args); + if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) + if(run_only_kernel != driver->get_kernel_name(&selected_tunable)){ + printf("heuristic selected tunable not match your request\n"); + return; + } + + result_t result = launch(&selected_tunable, 0); + }else{ + assert(0); + } +} + int main(int argc, char **argv) { char *hsaco = env_get_str("IGEMM_HSACO", IGEMM_HSACO); char *config_file = env_get_str("IGEMM_CONFIG_FILE", 
IGEMM_CONFIG_FILE); std::string run_only_kernel = env_get_str("IGEMM_RUN_ONLY_KERNEL", IGEMM_RUN_ONLY_KERNEL_DEFAULT); int warmup = env_get_int("IGEMM_WARMUP", WARMUP); int repeat = env_get_int("IGEMM_REPEAT", REPEAT); - int sclk_mhz = env_get_int("IGEMM_SCLK_MHZ", SCLK_MHZ); - int log_fastest_config = env_get_int("IGEMM_LOG_FASTEST_CONFIG", 0); - int wrw_kernel_selection = env_get_int("IGEMM_LOG_SELECTED_CONFIG", 0); int assert_when_invalid = env_get_int("IGEMM_ASSERT_WHEN_INVALID", 0); + int verbose = env_get_int("IGEMM_VERBOSE", 0); + driver_mode_t driver_mode = static_cast(env_get_int("IGEMM_MODE", 0)); config_parser_t config_parser(config_file); auto content = config_parser.parse(); //content.dump(); @@ -365,13 +524,31 @@ int main(int argc, char **argv) { printf("no tunable specified, may not work\n"); return 0; } - // printf("tunables:%d\n", tunables.size()); + // printf("tunables:%d, hsaco:%s\n", tunables.size(), hsaco); hipModule_t module; HIP_CALL(hipModuleLoad(&module, hsaco)); + std::string base_arg = create_base_args(argc, argv); args_t conv_args = create_conv_args(argc, argv); // dump_arg(&conv_args); + driverDataType_t driver_data_type; + + if(base_arg == "conv"){ + driver_data_type = driverFloat; + } + else if(base_arg == "convfp16"){ + driver_data_type = driverHalf; + } + else if(base_arg == "convbf16") { + driver_data_type = driverBFloat16; + exit(0); + } + else if(base_arg == "convint8") { + driver_data_type = driverInt8; + } + else + exit(0); int hi = conv_args.get_int("in_h"); int wi = conv_args.get_int("in_w"); @@ -419,37 +596,6 @@ int main(int argc, char **argv) { int need_verify = conv_args.get_int("verify"); - // printf("fwd:%d, bwd:%d, wrw:%d, verify:%d\n",need_fwd, need_bwd, need_wrw, need_verify); - - int num_cu; - int num_simd = 64; // hard coded - int gcn_arch = 0; - { - hipDeviceProp_t dev_prop; - hipDevice_t dev; - HIP_CALL(hipGetDevice(&dev)); - HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); - num_cu = dev_prop.multiProcessorCount; - gcn_arch = dev_prop.gcnArch; -#if 0 -#define P_DEVICE_PROP_INT(prop) \ - printf(#prop":%d\n", dev_prop.prop) - - - P_DEVICE_PROP_INT(clockRate); - P_DEVICE_PROP_INT(memoryClockRate); - P_DEVICE_PROP_INT(memoryBusWidth); - P_DEVICE_PROP_INT(major); - P_DEVICE_PROP_INT(minor); - P_DEVICE_PROP_INT(gcnArch); -#endif - } - if(gcn_arch == 908){ - num_simd = 4 * 32 ; // 4x miSIMD, 32x mac unit - } - double fp32_gflops = - theoritical_fp32_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd); - if (need_fwd){ result_t fastest_result_fwd; fastest_result_fwd.duration_ms = FLT_MAX; @@ -505,34 +651,16 @@ int main(int argc, char **argv) { HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); - igemm_fwd_gtc_t conv_fwd_driver; - double nrms = get_fwd_nrms(); - for (int i = 0; i < tunables.size(); i++) { - igemm_gtc_tunable_t *tunable = &tunables[i]; - if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) - if(run_only_kernel != conv_fwd_driver.get_kernel_name(tunable)) - continue; - - printf("[fwd:%2d] %s, ", i, conv_fwd_driver.get_kernel_name(tunable).c_str()); - fflush(stdout); + igemm_fwd_gtc_t conv_fwd_driver(module, driver_mode, driver_data_type, warmup, repeat, verbose); + auto fwd_pre = [&](){ if (need_verify) HIP_CALL(hipMemset(device_output, 0, static_cast(n) * k * ho * wo * sizeof(float))); + }; - result_t result = - conv_fwd_driver.run(&conv_args, tunable, module, device_input, - device_weight, device_output, warmup, repeat); - if (result.return_code != 0){ - 
printf("not applicatble\n"); - continue; - } - - double gflops = measured_fp32_conv_gflops( - result.duration_ms, n, c, hi, wi, k, y, x, stride_h, stride_w, - dilation_h, dilation_w, pad_h, pad_w, ngroups); - printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms, - gflops / 1000 , (gflops / fp32_gflops) * 100); + auto fwd_post = [&](){ if (need_verify) { + double nrms = get_fwd_nrms(); HIP_CALL(hipMemcpy(device_output_to_host, device_output, static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyDeviceToHost)); @@ -541,26 +669,10 @@ int main(int argc, char **argv) { printf(", valid:%s", is_valid ? "y" : "n"); if(assert_when_invalid) assert(is_valid); } - printf("\n"); - if(result.duration_ms < fastest_result_fwd.duration_ms){ - fastest_result_fwd = result; - fastest_result_fwd.gflops = (float)gflops; - fastest_result_fwd.efficiency = (gflops / fp32_gflops) * 100; - fastest_id = i; - } - } - if(log_fastest_config){ - dump_arg(&conv_args); - if(fastest_id == -1) - printf(" fastest: no suitable kernel\n"); - else - printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", - fastest_id, - fastest_result_fwd.kernel_name.c_str(), - fastest_result_fwd.duration_ms, - fastest_result_fwd.gflops / 1000, - fastest_result_fwd.efficiency); - } + }; + + launch_conv_driver(&conv_fwd_driver, &conv_args, tunables, "fwd", device_input, device_weight, device_output, fwd_pre, fwd_post); + if (need_verify) free(device_output_to_host); } @@ -619,35 +731,16 @@ int main(int argc, char **argv) { HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); + igemm_bwd_gtc_t conv_bwd_driver(module, driver_mode, driver_data_type, warmup, repeat, verbose); - igemm_bwd_gtc_t conv_bwd_driver; - double nrms = get_bwd_nrms(); - for (int i = 0; i < tunables.size(); i++) { - igemm_gtc_tunable_t *tunable = &tunables[i]; - if(run_only_kernel != IGEMM_RUN_ONLY_KERNEL_DEFAULT) - if(run_only_kernel != conv_bwd_driver.get_kernel_name(tunable)) - continue; - - printf("[bwd:%2d] %s, ", i, conv_bwd_driver.get_kernel_name(tunable).c_str()); - fflush(stdout); - + auto bwd_pre = [&](){ if (need_verify) - HIP_CALL(hipMemset(device_input, 0x7f, - static_cast(n) * c * hi * wi * sizeof(float))); // 0x7f7f7f7f ~= 7.41e+28, a very large number - result_t result = - conv_bwd_driver.run(&conv_args, tunable, module, device_input, - device_weight, device_output, warmup, repeat); - if (result.return_code != 0){ - printf("not applicatble\n"); - continue; - } + HIP_CALL(hipMemset(device_input, 0x7f, static_cast(n) * c * hi * wi * sizeof(float))); // 0x7f7f7f7f ~= 7.41e+28, a very large number + }; - double gflops = measured_fp32_conv_gflops( - result.duration_ms, n, c, hi, wi, k, y, x, stride_h, stride_w, - dilation_h, dilation_w, pad_h, pad_w, ngroups); - printf("cost:%.3fms, tflops:%.3f(%.2f%%)", result.duration_ms, - gflops / 1000 , (gflops / fp32_gflops) * 100); + auto bwd_post = [&](){ if (need_verify) { + double nrms = get_bwd_nrms(); HIP_CALL(hipMemcpy(device_input_to_host, device_input, static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyDeviceToHost)); @@ -655,34 +748,15 @@ int main(int argc, char **argv) { static_cast(n) * c * hi * wi, nrms); printf(", valid:%s", is_valid ? 
"y" : "n"); if(assert_when_invalid) assert(is_valid); - // if (!is_valid) { - // printf("\n"); - // break; - // } - } - printf("\n"); - if(result.duration_ms < fastest_result_bwd.duration_ms){ - fastest_result_bwd = result; - fastest_result_bwd.gflops = (float)gflops; - fastest_result_bwd.efficiency = (gflops / fp32_gflops) * 100; - fastest_id = i; } - } - if(log_fastest_config){ - dump_arg(&conv_args); - if(fastest_id == -1) - printf(" fastest: no suitable kernel\n"); - else - printf(" fastest: [%d]%s, cost:%.3fms, tflops:%.3f(%.2f%%)\n", - fastest_id, - fastest_result_bwd.kernel_name.c_str(), - fastest_result_bwd.duration_ms, - fastest_result_bwd.gflops / 1000, - fastest_result_bwd.efficiency); - } + }; + + launch_conv_driver(&conv_bwd_driver, &conv_args, tunables, "bwd", device_input, device_weight, device_output, bwd_pre, bwd_post); + if (need_verify) free(device_input_to_host); } + if (need_wrw){ float *device_weight_to_host = NULL; if (need_verify) { @@ -733,38 +807,34 @@ int main(int argc, char **argv) { HIP_CALL(hipMemcpy(device_output, host_output, static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyHostToDevice)); -#if 0 - printf("input\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check*hi*wi, host_input[i_check*hi*wi], ((int *)host_input)[i_check*hi*wi]); - } - printf("output\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check*ho*wo, host_output[i_check*ho*wo], ((int *)host_output)[i_check*ho*wo]); - } - printf("input\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check, host_input[i_check], ((int *)host_input)[i_check]); - } - printf("output\r\n"); - for (int i_check = 0; i_check < (0+32); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check, host_output[i_check], ((int *)host_output)[i_check]); - } - printf("workspace debug end \r\n"); -#endif - igemm_wrw_gtc_t conv_wrw_driver; - float min_duration = 10000000.0f; - float selected_duration = 10000000.0f; - double nrms = get_wrw_nrms(); - std::string kernel_name; - std::string selected_kernel; + igemm_wrw_gtc_t conv_wrw_driver(module, driver_mode, driver_data_type, warmup, repeat, verbose); + + auto wrw_pre = [&](){ + if (need_verify) + HIP_CALL(hipMemset(device_weight, 0, static_cast(k) * c * y * x * sizeof(float))); + }; + auto wrw_post = [&](){ + if (need_verify) { + double nrms = get_wrw_nrms(); + HIP_CALL(hipMemcpy(device_weight_to_host, device_weight, + static_cast(ngroups) * (k / ngroups) * (c / ngroups) * y * x * sizeof(float), + hipMemcpyDeviceToHost)); + bool is_valid = valid_vector(host_weight, device_weight_to_host, + static_cast(ngroups) * (k / ngroups) * (c / ngroups) * y * x, nrms); + printf(", valid:%s", is_valid ? 
"y" : "n"); + if(assert_when_invalid) assert(is_valid); + } + }; + + launch_conv_driver(&conv_wrw_driver, &conv_args, tunables, "wrw", device_input, device_weight, device_output, wrw_pre, wrw_post); + + if (need_verify) + free(device_weight_to_host); + +#if 0 selected_kernel = conv_wrw_driver.select_kernel(&conv_args, tunables); int min_grid = 0; @@ -846,6 +916,7 @@ int main(int argc, char **argv) { } if (need_verify) free(device_weight_to_host); +#endif } free(host_input); diff --git a/driver/igemm_bwd_gtc_driver.h b/driver/igemm_bwd_gtc_driver.h index e685482d..52bca0d4 100755 --- a/driver/igemm_bwd_gtc_driver.h +++ b/driver/igemm_bwd_gtc_driver.h @@ -40,9 +40,9 @@ // #define IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL 1 typedef struct { - float *p_in; - float *p_wei; - float *p_out; + void *p_in; + void *p_wei; + void *p_out; int hi; int wi; int n; @@ -163,14 +163,13 @@ static void dump_bwd_karg(igemm_bwd_gtc_karg_t * karg){ std::cout<fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_MAC || tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS){ return tunable->gemm_m_level0_cluster * tunable->gemm_n_level0_cluster * tunable->gemm_m_level1_cluster * tunable->gemm_n_level1_cluster; @@ -179,10 +178,9 @@ class igemm_bwd_gtc_t { int waves_per_n = tunable->gemm_n_per_block / (tunable->wave_tile_n * tunable->wave_step_n * tunable->wave_repeat_n); return waves_per_m * waves_per_n * AMDGPU_WAVE_SIZE; } - } - int get_grid_size(const args_t *arg, - const igemm_gtc_tunable_t *tunable) { + size_t get_grid_size(const args_t *arg, + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); int n = arg->get_int("batchsize"); @@ -254,7 +252,7 @@ class igemm_bwd_gtc_t { } bool tunable_is_valid(const args_t *arg, - const igemm_gtc_tunable_t *tunable) + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); @@ -363,8 +361,7 @@ class igemm_bwd_gtc_t { } result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, - hipModule_t module, float *p_in, float *p_wei, float *p_out, - int warmup, int repeat) { + void *p_in, void *p_wei, void *p_out) override { if (!tunable_is_valid(arg, tunable)) { result_t result; result.return_code = -1; @@ -498,8 +495,8 @@ class igemm_bwd_gtc_t { if(y < stride_h || x < stride_w || dilation_h != 1 || dilation_w != 1) need_set_zero = true; - int block_size = get_block_size(tunable); - int grid_size = get_grid_size(arg, tunable); + size_t block_size = get_block_size(tunable); + size_t grid_size = get_grid_size(arg, tunable); #ifdef IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL igemm_upsampling_clear_karg_t ukarg; @@ -547,36 +544,15 @@ class igemm_bwd_gtc_t { HIP_CALL( hipModuleGetFunction(&upsampling_clear_kernel_func, module, upsampling_clear_kernel_name.c_str())); #endif + result_t result; - auto launch_bwd = [&]() -> float{ - float ms_total = .0; - if(need_set_zero){ - float ms = .0; - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); -#ifdef IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &ukarg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &ukarg_size, - HIP_LAUNCH_PARAM_END}; - HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1, - u_block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); -#else - HIP_CALL(hipDeviceSynchronize()); - HIP_CALL(hipEventRecord( start, NULL )); - hipMemset(p_in, 0, n*c*hi*wi*sizeof(float)); - HIP_CALL(hipEventRecord( stop, NULL )); 
-#endif - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); - - ms_total += ms; - } - + if(tunable->multihead){ + // TODO: + }else{ + std::vector kernels; + std::vector kargs; + kargs.reserve(num_of_gemm); // CAUSION! we do not want this vector to be reallocated and move to other place + int valid_kernel_index = 0; for(int gemm_id = 0; gemm_id < num_of_gemm; gemm_id++){ int i_y_tilda = gemm_id / x_tilda; int i_x_tilda = gemm_id % x_tilda; @@ -585,154 +561,46 @@ class igemm_bwd_gtc_t { int gemm_k = (k / group) * y_dot_slice * x_dot_slice; bool is_gemm_not_empty = gemm_k > 0 && y_dot_slice > 0 && x_dot_slice > 0; - - karg.dtile_iy = i_y_tilda; - karg.dtile_ix = i_x_tilda; - karg.dslice_y = y_dot_slice; - karg.dslice_x = x_dot_slice; + if(is_gemm_not_empty){ + kargs.push_back(karg); + kargs[valid_kernel_index].dtile_iy = i_y_tilda; + kargs[valid_kernel_index].dtile_ix = i_x_tilda; + kargs[valid_kernel_index].dslice_y = y_dot_slice; + kargs[valid_kernel_index].dslice_x = x_dot_slice; #if USE_MAGIC_DIV - magic_div_u32_t mdiv_0 = is_gemm_not_empty ? magic_div_u32_gen(y_dot_slice * x_dot_slice) : magic_div_u32_t({0, 0}); - magic_div_u32_t mdiv_1 = is_gemm_not_empty ? magic_div_u32_gen(x_dot_slice) : magic_div_u32_t({0, 0}); - karg.magic_0 = mdiv_0.magic; - karg.magic_1 = mdiv_1.magic; + magic_div_u32_t mdiv_0 = is_gemm_not_empty ? magic_div_u32_gen(y_dot_slice * x_dot_slice) : magic_div_u32_t({0, 0}); + magic_div_u32_t mdiv_1 = is_gemm_not_empty ? magic_div_u32_gen(x_dot_slice) : magic_div_u32_t({0, 0}); + kargs[valid_kernel_index].magic_0 = mdiv_0.magic; + kargs[valid_kernel_index].magic_1 = mdiv_1.magic; - karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); - karg.shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); + kargs[valid_kernel_index].shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, mdiv_3.shift); + kargs[valid_kernel_index].shift_pack_1 = magic_div_u32_pack_shift(mdiv_4.shift, mdiv_5.shift, mdiv_6.shift, 0); #endif - // printf("start launch id:%d(%d), block:%d, grid:%d\n", gemm_id, is_gemm_not_empty?1:0, block_size, grid_size); - // dump_bwd_karg(&karg); + // printf("start launch id:%d(%d), block:%d, grid:%d\n", gemm_id, is_gemm_not_empty?1:0, block_size, grid_size); + // dump_bwd_karg(&kargs[valid_kernel_index]); - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; + kernels.push_back({kernel_func, (void*)&kargs[valid_kernel_index], karg_size, std::vector{grid_size * block_size, 1, 1}, std::vector{block_size, 1, 1}}); - if(is_gemm_not_empty){ -#if USE_EXT_MODULE_LAUNCH - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); - // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); -#else - gpu_timer_t timer(NULL); - timer.start(); - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config)); - timer.stop(); - ms = timer.duration(); -#endif + valid_kernel_index++; } - ms_total += ms; } - return ms_total; - }; - - auto 
launch_bwd_multihead = [&]() -> float{ - float ms_total = .0; - if(need_set_zero){ - float ms = .0; - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); -#ifdef IGEMM_BWD_UPSAMPLING_USE_CUSTOM_KERNEL - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &ukarg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &ukarg_size, - HIP_LAUNCH_PARAM_END}; - HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1, - u_block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); -#else - HIP_CALL(hipDeviceSynchronize()); - HIP_CALL(hipEventRecord( start, NULL )); - hipMemset(p_in, 0, n*c*hi*wi*sizeof(float)); - HIP_CALL(hipEventRecord( stop, NULL )); -#endif - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); - ms_total += ms; - } - // if 1x1 and stride/dilation > 1, will have empty gemms which will waste launch grid. better ignore that case at runtime - int origin_grid_size = grid_size/num_of_gemm; - karg.dtile_iy = origin_grid_size; - karg.dtile_ix = x_dot | (y_dot<<16); - karg.dslice_y = y % y_dot; - karg.dslice_x = x % x_dot; - // printf("start launch id:%d(%d), block:%d, grid:%d\n", gemm_id, is_gemm_not_empty?1:0, block_size, grid_size); - // dump_bwd_karg(&karg); - - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; -#if USE_EXT_MODULE_LAUNCH - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); - // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); -#else - gpu_timer_t timer(NULL); - timer.start(); - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config)); - timer.stop(); - ms = timer.duration(); -#endif - ms_total += ms; - return ms_total; - }; - - auto launch_bwd_driver = [&](){ - if(tunable->multihead) - return launch_bwd_multihead(); - else - return launch_bwd(); - }; - - for (int i = 0; i < warmup; i++) { - launch_bwd_driver(); - } - std::vector duration_list; - for (int i = 0; i < repeat; i++) { - float d = launch_bwd_driver(); - duration_list.push_back(d); + assert(kernels.size() == valid_kernel_index && kargs.size() == valid_kernel_index); + auto bwd_epilog = need_set_zero ? 
+ std::function{[&]() -> float{ + hipMemset(p_in, 0, n*c*hi*wi*sizeof(float)); + return .0; + }} : + std::function{[&]() -> float{ + return .0; + }}; + float ms = igemm_launch_kernels_with_epilog(kernels, bwd_epilog, warmup, repeat); + + result.return_code = 0; + result.duration_ms = ms; } - // remove min and max from list, then do average - auto imin = std::min_element(begin(duration_list), end(duration_list)); - duration_list.erase(imin); - auto imax = std::max_element(begin(duration_list), end(duration_list)); - duration_list.erase(imax); - assert(duration_list.size() == (repeat - 2)); - float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); - usleep(1000 * 5); - - result_t result; - result.return_code = 0; - result.duration_ms = avg_duration; - result.kernel_name = kernel_name; return result; } }; diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index fc7f92b9..99f1758c 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -36,11 +36,15 @@ #include #include #include +#include + +//#define GEMM_K_GLOBAL_SPLIT 3 +#define MAX_GEMM_K_SPLITS 8 typedef struct { - float *p_in; - float *p_wei; - float *p_out; + void *p_in; + void *p_wei; + void *p_out; int hi; int wi; int n; @@ -67,14 +71,14 @@ typedef struct { uint32_t magic_6; // denom: n*b*k / (m_per_block*n_per_block) uint32_t shift_pack_0; uint32_t shift_pack_1; - uint32_t __pack_0; + uint32_t ks; #endif } __attribute__((packed)) igemm_fwd_gtc_karg_t; typedef struct { - float *p_in; - float *p_wei; - float *p_out; + void *p_in; + void *p_wei; + void *p_out; int hi; int wi; int n; @@ -97,7 +101,7 @@ typedef struct { uint32_t magic_2; // denom: wo uint32_t magic_3; // denom: (gemm_m/m_per_block) * (gemm_n/n_per_block) uint32_t shift_pack_0; - uint32_t __pack_0; + uint32_t ks; #endif } __attribute__((packed)) igemm_fwd_gtc_nhwc_karg_t; @@ -137,14 +141,13 @@ static void dump_fwd_karg(igemm_fwd_gtc_karg_t * karg){ std::cout<fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_MAC || tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS){ return tunable->gemm_m_level0_cluster * tunable->gemm_n_level0_cluster * tunable->gemm_m_level1_cluster * tunable->gemm_n_level1_cluster; @@ -154,8 +157,9 @@ class igemm_fwd_gtc_t { return waves_per_m * waves_per_n * AMDGPU_WAVE_SIZE; } } - int get_grid_size(const args_t *arg, - const igemm_gtc_tunable_t *tunable) { + // return grid size without consideration of split k + size_t get_grid_size(const args_t *arg, + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); int n = arg->get_int("batchsize"); @@ -174,7 +178,7 @@ class igemm_fwd_gtc_t { int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); int group = arg->get_int("group_count"); - int splits = split_batch_size(arg, tunable); + size_t splits = igemm_split_batch_size(arg, utility_string_to_data_byte(tunable->precision)); n = n/splits; // split batch size here int gemm_m_per_block = tunable->gemm_m_per_block; @@ -182,6 +186,7 @@ class igemm_fwd_gtc_t { int nxe = tunable->nxe; int nxb = tunable->nxb; int b = ho * wo; + if(tunable->tensor_layout == "nchw") b = nxe == 0 ? (ho * wo) : ((ho * wo + nxb - 1) / nxb) * nxb; // pad to nxb modulo when nxe != 0 @@ -204,56 +209,8 @@ class igemm_fwd_gtc_t { return grid_size; } - // this is to support big tensor > 4G. 
need to decide how many splits needed - // return the number of splits - int split_batch_size(const args_t *arg, const igemm_gtc_tunable_t *tunable) - { - int hi = arg->get_int("in_h"); - int wi = arg->get_int("in_w"); - int n = arg->get_int("batchsize"); - int k = arg->get_int("out_channels"); - int c = arg->get_int("in_channels"); - - int stride_h = arg->get_int("conv_stride_h"); - int stride_w = arg->get_int("conv_stride_w"); - int dilation_h = arg->get_int("dilation_h"); - int dilation_w = arg->get_int("dilation_w"); - int pad_h = arg->get_int("pad_h"); - int pad_w = arg->get_int("pad_w"); - int y = arg->get_int("fil_h"); - int x = arg->get_int("fil_w"); - int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); - int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); - - int data_byte = utility_string_to_data_byte(tunable->precision); - size_t image_size_input = static_cast(c) * hi * wi * data_byte; - size_t image_size_output = static_cast(k) * ho * wo * data_byte; - size_t size_4g = 0xffffffffUL; - if(image_size_input >= size_4g || image_size_output >= size_4g) - return 0; - - size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output; - size_t splited_n = size_4g / image_size; - - // round up splits, we must match - // 1. splited_n * image_size < size_4g - // 2. n % splited_n == 0 - // if(splited_n >= n) - // return 1; - assert(splited_n != 0); - while(splited_n >= 1){ - // printf("n:%d, splited_n:%d\n", n, splited_n); - if(n % splited_n == 0) - break; - splited_n--; - } - - assert(splited_n * image_size < size_4g && n % splited_n == 0); - return n / splited_n; - } - bool tunable_is_valid(const args_t *arg, - const igemm_gtc_tunable_t *tunable) + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); @@ -275,7 +232,7 @@ class igemm_fwd_gtc_t { assert(c % group == 0 && k % group == 0); - int splits = split_batch_size(arg, tunable); + size_t splits = igemm_split_batch_size(arg, utility_string_to_data_byte(tunable->precision)); if(splits == 0){ printf("image size (c*h*w) is bigger than 4g, which is not supported now\n"); return false; @@ -357,7 +314,7 @@ class igemm_fwd_gtc_t { // return false; //} - if((c / group) % gemm_k_per_block != 0) + if(((c >> tunable->gemm_k_global_split) / group) % gemm_k_per_block != 0) return false; // if(gemm_m_per_block % tunable->nxb != 0){ @@ -406,9 +363,7 @@ class igemm_fwd_gtc_t { return true; } - result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, - hipModule_t module, float *p_in, float *p_wei, float *p_out, - int warmup, int repeat) { + result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, void *p_in, void *p_wei, void *p_out) override { if (!tunable_is_valid(arg, tunable)) { result_t result; result.return_code = -1; @@ -436,7 +391,7 @@ class igemm_fwd_gtc_t { assert(c % group == 0 && k % group == 0); - int splits = split_batch_size(arg, tunable); + size_t splits = igemm_split_batch_size(arg, utility_string_to_data_byte(tunable->precision)); n = n/splits; // split batch size here int gemm_m_per_block = tunable->gemm_m_per_block; @@ -548,76 +503,65 @@ class igemm_fwd_gtc_t { assert(0); } - int block_size = get_block_size(tunable); - int grid_size = get_grid_size(arg, tunable); + size_t block_size = get_block_size(tunable); hipFunction_t kernel_func; std::string kernel_name = get_kernel_name(tunable); - // printf("kernel:%s\n, block:%d, grid:%d\n", kernel_name.c_str(), block_size, grid_size); - HIP_CALL( - 
hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); - - auto launch_fwd = [&]() -> float { - // printf("launch fwd block:%d, grid:%dx%d\n", block_size, grid_size, splits); - // dump_fwd_karg(&karg); - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, static_cast(&karg_buffer[0]), - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; - -#if USE_EXT_MODULE_LAUNCH - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); - - // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); - - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); -#else - gpu_timer_t timer(NULL); - timer.start(); - - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config)); - - timer.stop(); - ms = timer.duration(); -#endif - return ms; - }; + HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_name.c_str())); + + // TODO: use kernel to pre-clear when atomic + auto fwd_epilog = tunable->gemm_k_global_split ? + std::function{[&]() -> float{ + hipMemset(p_out, 0, static_cast(n) * k * ho * wo * sizeof(float)); + return .0; + }} : + std::function{[&]() -> float{ + return .0; + }}; - for (int i = 0; i < warmup; i++) { - launch_fwd(); - } + result_t result; + result.kernel_name = kernel_name; + if(this->driver_mode == driver_mode_normal){ + float min_duration = FLT_MAX; + int selected_gks = 0; + int max_split_num = tunable->gemm_k_global_split == 0 ? + 0 : igemm_get_max_gks(c, tunable->gemm_k_per_block, MAX_GEMM_K_SPLITS); + for(int gks = 0; gks <= max_split_num; gks++){ + size_t grid_size = get_grid_size(arg, tunable) * (1 << gks); + if(tunable->tensor_layout == "nhwc"){ + // This is hacky, but in MIOpen we prefer a heuristic way to set gks, so ok now. 
+ igemm_fwd_gtc_nhwc_karg_t *karg_revalue = (igemm_fwd_gtc_nhwc_karg_t *)(karg_buffer); + karg_revalue->ks = gks; + } + float duration = igemm_launch_kernels_with_epilog({ + {kernel_func, karg_buffer, karg_size, {grid_size * block_size, splits, 1}, {block_size, 1, 1}} + }, fwd_epilog, this->warmup, this->repeat); + + if(min_duration > duration){ + min_duration = duration; + selected_gks = gks; + } + } - std::vector duration_list; - for (int i = 0; i < repeat; i++) { - float d = launch_fwd(); - duration_list.push_back(d); - } + result.return_code = 0; + result.duration_ms = min_duration; + result.gks = selected_gks; + }else if(this->driver_mode == driver_mode_heuristic){ + int gks = heuristic_select_gks(arg, tunable); + size_t grid_size = get_grid_size(arg, tunable) * (1 << gks); - // remove min and max from list, then do average - auto imin = std::min_element(begin(duration_list), end(duration_list)); - duration_list.erase(imin); - auto imax = std::max_element(begin(duration_list), end(duration_list)); - duration_list.erase(imax); - assert(duration_list.size() == (repeat - 2)); - float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + float duration = igemm_launch_kernels_with_epilog({ + {kernel_func, karg_buffer, karg_size, {grid_size * block_size, splits, 1}, {block_size, 1, 1}} + }, fwd_epilog, this->warmup, this->repeat); - usleep(1000 * 5); + result.return_code = 0; + result.duration_ms = duration; + result.gks = gks; + }else{ + assert(0); + } - result_t result; - result.return_code = 0; - result.duration_ms = avg_duration; - result.kernel_name = kernel_name; + usleep(1000 * 5); return result; } }; diff --git a/driver/igemm_gtc_base.h b/driver/igemm_gtc_base.h index c3b4bfac..0409ec8c 100755 --- a/driver/igemm_gtc_base.h +++ b/driver/igemm_gtc_base.h @@ -34,6 +34,8 @@ #include #include #include +#include +#include #define IGEMM_GTC_TUNABLE_FMA_TYPE_MAC "mac" #define IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS "dlops" @@ -41,6 +43,19 @@ #define IGEMM_GTC_TUNABLE_FMA_TYPE_NA "fma_na" #define AMDGPU_WAVE_SIZE 64 +typedef enum { + driverHalf = 0, /*!< 16-bit floating point (Fully supported) */ + driverFloat = 1, /*!< 32-bit floating point (Fully supported) */ + driverInt8 = 3, + driverBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction) + (Partially supported) */ +} driverDataType_t; + +typedef enum { + driver_mode_normal = 0, // bench all solutions + driver_mode_heuristic = 1, // find suitable heuristic +} driver_mode_t; + #if USE_MAGIC_DIV typedef struct { uint32_t magic; @@ -320,4 +335,214 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) { return kernel_name; } +static inline float igemm_launch_kernel_single(hipFunction_t kernel_func, void* args, size_t arg_size, std::vector grid_size, std::vector block_size) +{ + void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, args, + HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size, + HIP_LAUNCH_PARAM_END}; + float ms = .0; + + hipEvent_t start; + hipEvent_t stop; + hipEventCreate(&start); + hipEventCreate(&stop); + + // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem + HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size[0], grid_size[1], grid_size[2], + block_size[0], block_size[1], block_size[2], 0, 0, NULL, + (void **)&config, start, stop)); + + + hipEventSynchronize(stop); + hipEventElapsedTime(&ms, start, stop); + hipEventDestroy(start); + hipEventDestroy(stop); + + return ms; +} + +static inline float 
igemm_launch_kernel(hipFunction_t kernel_func, void* args, size_t arg_size, std::vector grid_size, std::vector block_size, int warmup, int repeat) +{ + assert(repeat > 2); + std::vector duration_list; + for (int i = 0; i < warmup; i++) { + igemm_launch_kernel_single(kernel_func, args, arg_size, grid_size, block_size); + } + + for (int i = 0; i < repeat; i++) { + float d = igemm_launch_kernel_single(kernel_func, args, arg_size, grid_size, block_size); + duration_list.push_back(d); + } + // remove min and max from list, then do average + auto imin = std::min_element(begin(duration_list), end(duration_list)); + duration_list.erase(imin); + auto imax = std::max_element(begin(duration_list), end(duration_list)); + duration_list.erase(imax); + + assert(duration_list.size() == (repeat - 2)); + float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + return avg_duration; +} + +typedef struct{ + hipFunction_t kernel_func; + void * args; + size_t arg_size; + std::vector grid_size; + std::vector block_size; +}igemm_launch_kernel_t; +static inline float igemm_launch_kernels(const std::vector & kernels, int warmup, int repeat) +{ + auto launch_kernels = [&]() -> float{ + float ms = .0; + for(const auto ker : kernels) + ms += igemm_launch_kernel_single(ker.kernel_func, ker.args, ker.arg_size, ker.grid_size, ker.block_size); + return ms; + }; + + assert(repeat > 2); + std::vector duration_list; + for (int i = 0; i < warmup; i++) { + launch_kernels(); + } + + for (int i = 0; i < repeat; i++) { + float d = launch_kernels(); + duration_list.push_back(d); + } + // remove min and max from list, then do average + auto imin = std::min_element(begin(duration_list), end(duration_list)); + duration_list.erase(imin); + auto imax = std::max_element(begin(duration_list), end(duration_list)); + duration_list.erase(imax); + + assert(duration_list.size() == (repeat - 2)); + float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + return avg_duration; +} +template +static inline float igemm_launch_kernels_with_epilog(const std::vector & kernels, epilog_kernel_t epilog_kernel, int warmup, int repeat) +{ + auto launch_kernels = [&]() -> float{ + float ms = .0; + ms += epilog_kernel(); + for(const auto & ker : kernels) + ms += igemm_launch_kernel_single(ker.kernel_func, ker.args, ker.arg_size, ker.grid_size, ker.block_size); + return ms; + }; + + assert(repeat > 2); + std::vector duration_list; + for (int i = 0; i < warmup; i++) { + launch_kernels(); + } + + for (int i = 0; i < repeat; i++) { + float d = launch_kernels(); + duration_list.push_back(d); + } + // remove min and max from list, then do average + auto imin = std::min_element(begin(duration_list), end(duration_list)); + duration_list.erase(imin); + auto imax = std::max_element(begin(duration_list), end(duration_list)); + duration_list.erase(imax); + + assert(duration_list.size() == (repeat - 2)); + float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); + return avg_duration; +} + +static inline int igemm_get_max_gks(int gemm_k, int gemm_k_per_block, int max_log2_splits) +{ + int gks = gemm_k % gemm_k_per_block == 0 ? (int)log2(gemm_k / gemm_k_per_block) : 0; + if(gks > max_log2_splits) + gks = max_log2_splits; + return gks; +} + +// this is to support big tensor > 4G. 
need to decide how many splits needed +// return the number of splits +static inline size_t igemm_split_batch_size(const args_t *arg, int data_byte) +{ + int hi = arg->get_int("in_h"); + int wi = arg->get_int("in_w"); + int n = arg->get_int("batchsize"); + int k = arg->get_int("out_channels"); + int c = arg->get_int("in_channels"); + + int stride_h = arg->get_int("conv_stride_h"); + int stride_w = arg->get_int("conv_stride_w"); + int dilation_h = arg->get_int("dilation_h"); + int dilation_w = arg->get_int("dilation_w"); + int pad_h = arg->get_int("pad_h"); + int pad_w = arg->get_int("pad_w"); + int y = arg->get_int("fil_h"); + int x = arg->get_int("fil_w"); + int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h); + int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w); + + // int data_byte = utility_string_to_data_byte(tunable->precision); + size_t image_size_input = static_cast(c) * hi * wi * data_byte; + size_t image_size_output = static_cast(k) * ho * wo * data_byte; + size_t size_4g = 0xffffffffUL; + if(image_size_input >= size_4g || image_size_output >= size_4g) + return 0; + + size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output; + size_t splited_n = size_4g / image_size; + + // round up splits, we must match + // 1. splited_n * image_size < size_4g + // 2. n % splited_n == 0 + // if(splited_n >= n) + // return 1; + assert(splited_n != 0); + while(splited_n >= 1){ + // printf("n:%d, splited_n:%d\n", n, splited_n); + if(n % splited_n == 0) + break; + splited_n--; + } + + assert(splited_n * image_size < size_4g && n % splited_n == 0); + return static_cast(n) / splited_n; +} + +class igemm_driver_base_t{ +public: + igemm_driver_base_t(hipModule_t module_, driver_mode_t driver_mode_, driverDataType_t data_type_, int warmup_, int repeat_, bool verbose_) : + module(module_), driver_mode(driver_mode_), data_type(data_type_), warmup(warmup_), repeat(repeat_), verbose(verbose_) + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipGetDeviceProperties(&dev_prop, dev)); + this->num_cu = dev_prop.multiProcessorCount; + this->gcn_arch = dev_prop.gcnArch; + if(this->gcn_arch >= 1000) + this->num_cu *= 2; + } + std::string get_kernel_name(const igemm_gtc_tunable_t *tunable) { + return igemm_gtc_encode_kernel_name(tunable); + } + + virtual size_t get_block_size(const igemm_gtc_tunable_t *tunable) = 0; + virtual size_t get_grid_size(const args_t *arg, const igemm_gtc_tunable_t *tunable) = 0; + virtual bool tunable_is_valid(const args_t *arg, const igemm_gtc_tunable_t *tunable) = 0; + virtual result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, void *p_in, void *p_wei, void *p_out) = 0; + + virtual igemm_gtc_tunable_t heuristic_select_kernel(const args_t *arg) {return igemm_gtc_tunable_t{}; } + virtual int heuristic_select_gks(const args_t *arg, const igemm_gtc_tunable_t *tunable) {return 0; } + + hipModule_t module; + driver_mode_t driver_mode; + driverDataType_t data_type; + int warmup; + int repeat; + bool verbose; + + int num_cu; + int gcn_arch; +}; + #endif \ No newline at end of file diff --git a/driver/igemm_wrw_gtc_driver.h b/driver/igemm_wrw_gtc_driver.h index e63efae4..7bbd6bac 100644 --- a/driver/igemm_wrw_gtc_driver.h +++ b/driver/igemm_wrw_gtc_driver.h @@ -38,9 +38,9 @@ #include typedef struct { - float *p_in; - float *p_wei; - float *p_out; + void *p_in; + void *p_wei; + void *p_out; int hi; int wi; int n; @@ -85,83 +85,13 @@ static void dump_wrw_karg(igemm_wrw_gtc_karg_t * 
karg){ std::cout<gemm_m_per_block; - auto gemm_n_per_block = tunable->gemm_n_per_block; - auto gemm_k_per_block = tunable->gemm_k_per_block; - auto gemm_m_per_thread = tunable->gemm_m_per_thread; - auto gemm_m_level0_cluster = tunable->gemm_m_level0_cluster; - auto gemm_m_level1_cluster = tunable->gemm_m_level1_cluster; - auto gemm_n_per_thread = tunable->gemm_n_per_thread; - auto gemm_n_level0_cluster = tunable->gemm_n_level0_cluster; - auto gemm_n_level1_cluster = tunable->gemm_n_level1_cluster; - auto tensor_a_thread_lengths = tunable->tensor_a_thread_lengths; - auto tensor_a_cluster_lengths = tunable->tensor_a_cluster_lengths; - auto tensor_b_thread_lengths = tunable->tensor_b_thread_lengths; - auto tensor_b_cluster_lengths = tunable->tensor_b_cluster_lengths; - auto direction = tunable->direction; - auto precision = tunable->precision; - auto nxb = tunable->nxb; - auto nxe = tunable->nxe; - auto gemm_m_unmerge_cluster = tunable->gemm_m_unmerge_cluster; - auto gemm_n_unmerge_cluster = tunable->gemm_n_unmerge_cluster; - auto gemm_k_unmerge_cluster = tunable->gemm_k_unmerge_cluster; - auto multihead = tunable->multihead; - - assert(gemm_m_per_block % (gemm_m_per_thread * gemm_m_level0_cluster * gemm_m_level1_cluster) == 0); - assert(gemm_n_per_block % (gemm_n_per_thread * gemm_n_level0_cluster * gemm_n_level1_cluster) == 0); - int gemm_m_repeat = gemm_m_per_block / (gemm_m_per_thread * gemm_m_level0_cluster * gemm_m_level1_cluster); - int gemm_n_repeat = gemm_n_per_block / (gemm_n_per_thread * gemm_n_level0_cluster * gemm_n_level1_cluster); - - int thread_tile_m = gemm_m_repeat * gemm_m_per_thread; - int thread_tile_n = gemm_n_repeat * gemm_n_per_thread; - - assert(direction == "wrw"); - - std::string kernel_prefix = std::string("igemm_") + direction + std::string("_gtc_") + precision + - std::string("_bx") + std::to_string(nxb) + - std::string("_ex") + std::to_string(nxe) + "_"; - std::string kernel_name = - kernel_prefix + - "bt" + - std::to_string(gemm_m_per_block) + "x" + - std::to_string(gemm_n_per_block) + "x" + - std::to_string(gemm_k_per_block) + "_" + - "tt" + - std::to_string(thread_tile_m) + "x" + - std::to_string(thread_tile_n) + "_" + - "gm" + - std::to_string(gemm_m_repeat) + "x" + - std::to_string(gemm_m_level0_cluster) + "x" + - std::to_string(gemm_m_level1_cluster) + "_" + - "gn" + - std::to_string(gemm_n_repeat) + "x" + - std::to_string(gemm_n_level0_cluster) + "x" + - std::to_string(gemm_n_level1_cluster) + "_" + - "ta" + utility_int_list_to_string(tensor_a_thread_lengths) + "_" + - utility_int_list_to_string(tensor_a_cluster_lengths)+ "_" + - "tb" + utility_int_list_to_string(tensor_b_thread_lengths) + "_" + - utility_int_list_to_string(tensor_b_cluster_lengths); - // printf("[%s]\n",kernel_name.c_str()); - if(gemm_m_unmerge_cluster) - kernel_name += std::string("_mc"); - if(gemm_n_unmerge_cluster) - kernel_name += std::string("_nc"); - if(gemm_k_unmerge_cluster) - kernel_name += std::string("_kc"); - if(multihead) - kernel_name += std::string("_mh"); - return kernel_name; -#else - return igemm_gtc_encode_kernel_name(tunable); -#endif - } - int get_block_size(const igemm_gtc_tunable_t *tunable) { + + size_t get_block_size(const igemm_gtc_tunable_t *tunable) override { if(tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_MAC || tunable->fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS){ return tunable->gemm_m_level0_cluster * tunable->gemm_n_level0_cluster * tunable->gemm_m_level1_cluster * tunable->gemm_n_level1_cluster; @@ -177,8 +107,8 @@ class igemm_wrw_gtc_t { return 0; 
} } - int get_grid_size(const args_t *arg, - const igemm_gtc_tunable_t *tunable) { + size_t get_grid_size(const args_t *arg, + const igemm_gtc_tunable_t *tunable) override { int hi = arg->get_int("in_h"); int wi = arg->get_int("in_w"); int n = arg->get_int("batchsize"); @@ -224,7 +154,7 @@ class igemm_wrw_gtc_t { } bool tunable_is_valid(const args_t *arg, - const igemm_gtc_tunable_t *tunable) + const igemm_gtc_tunable_t *tunable) override { // TODO: int hi = arg->get_int("in_h"); @@ -634,8 +564,7 @@ class igemm_wrw_gtc_t { } result_t run(const args_t *arg, const igemm_gtc_tunable_t *tunable, - hipModule_t module, float *p_in, float *p_wei, float *p_out, - int warmup, int repeat) { + void *p_in, void *p_wei, void *p_out) override { if (!tunable_is_valid(arg, tunable)) { result_t result; result.return_code = -1; @@ -698,8 +627,8 @@ class igemm_wrw_gtc_t { //printf("gemmk split is %d\r\n", 1 << gemm_k_global_split); - int block_size = get_block_size(tunable); - int grid_size = get_grid_size(arg, tunable); + size_t block_size = get_block_size(tunable); + size_t grid_size = get_grid_size(arg, tunable); hipFunction_t kernel_func; std::string kernel_name = get_kernel_name(tunable); @@ -710,89 +639,27 @@ class igemm_wrw_gtc_t { // hipMemset(p_wei, 0x0, group * (k / group) * (c / group) * y * x * sizeof(float)); - auto launch_wrw_driver = [&](){ - void *config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, &karg, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &karg_size, - HIP_LAUNCH_PARAM_END}; - float ms = .0; - - if(gemm_k_global_split){ - // TODO: current implementation of global split K need pre-clear the wei tensor - // This may not be true in the future! + auto wrw_epilog = gemm_k_global_split ? + std::function{[&]() -> float{ hipMemset(p_wei, 0x0, group * (k / group) * (c / group) * y * x * sizeof(float)); - } - -#if USE_EXT_MODULE_LAUNCH - hipEvent_t start; - hipEvent_t stop; - hipEventCreate(&start); - hipEventCreate(&stop); - - // for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem - HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config, start, stop)); - - hipEventSynchronize(stop); - hipEventElapsedTime(&ms, start, stop); - hipEventDestroy(start); - hipEventDestroy(stop); -#else - gpu_timer_t timer(NULL); - timer.start(); - - HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1, - block_size, 1, 1, 0, 0, NULL, - (void **)&config)); - - timer.stop(); - ms = timer.duration(); -#endif - return ms; - }; - - for (int i = 0; i < warmup; i++) { - launch_wrw_driver(); - } - std::vector duration_list; - for (int i = 0; i < repeat; i++) { - float d = launch_wrw_driver(); - duration_list.push_back(d); - } + return .0; + }} : + std::function{[&]() -> float{ + return .0; + }}; - // for (int i = 0; i < warmup; i++) { - // hipMemset(p_wei, 0x0, k * c * y * x * sizeof(float)); - // launch_wrw_driver(); - // } - - // remove min and max from list, then do average - auto imin = std::min_element(begin(duration_list), end(duration_list)); - duration_list.erase(imin); - auto imax = std::max_element(begin(duration_list), end(duration_list)); - duration_list.erase(imax); - assert(duration_list.size() == (repeat - 2)); - float avg_duration = std::accumulate(duration_list.begin(), duration_list.end(), (float).0) / duration_list.size(); - - usleep(1000 * 1); - - // debug section of code -#if 0 - printf("workspace debug \r\n"); - float* gemmc_host_check = (float* )malloc((1 << gemm_k_global_split) * k * c * y * x * 
sizeof(float)); - hipMemcpy(gemmc_host_check, p_wei, k * c * y * x * sizeof(float), hipMemcpyDeviceToHost); - for (int i_check = 0; i_check < (0+block_size); i_check++) - { - printf("[%d]th var to monitor:[%f, %d]\r\n", i_check, gemmc_host_check[i_check], ((int *)gemmc_host_check)[i_check]); - } - printf("workspace debug end \r\n"); -#endif result_t result; + float duration = igemm_launch_kernels_with_epilog({ + {kernel_func, &karg, karg_size, {grid_size * block_size, 1, 1}, {block_size, 1, 1}} + }, wrw_epilog, this->warmup, this->repeat); + result.return_code = 0; - result.duration_ms = avg_duration; + result.duration_ms = duration; + result.gks = gemm_k_global_split; result.kernel_name = kernel_name; + return result; } }; - #endif \ No newline at end of file diff --git a/igemm/algo/igemm_fwd_gtc_nhwc.py b/igemm/algo/igemm_fwd_gtc_nhwc.py index a4237b1c..fab8d3a0 100755 --- a/igemm/algo/igemm_fwd_gtc_nhwc.py +++ b/igemm/algo/igemm_fwd_gtc_nhwc.py @@ -110,6 +110,7 @@ def flatten(x): ctrl_coalescing_store_xdlops = ctrl_coalescing_store_xdlops_t() ctrl_coalescing_store_xdlops.cxm = ctrl_xdlops_mapping + ctrl_coalescing_store_xdlops.gemm_k_global_split = self.tunable.gemm_k_global_split ctrl_coalescing_store_xdlops.coalescing_groups = self.coalescing_store_groups ctrl_coalescing_store_xdlops.data_byte = amdgpu_precision_data_byte(self.tunable.precision) @@ -242,9 +243,12 @@ def __init__(self, mc, tunable, inline, **options): if tunable.tensor_a_pass_through: self.declare_arg("s_in_base") self.declare_arg("s_in_c_itr") # - self.declare_arg("s_gemm_k_num_c") # used to U64 sub s_in_base, can be None + #self.declare_arg("s_gemm_k_num_c") # used to U64 sub s_in_base, can be None else: self.declare_arg("s_in_offset") # use this as c itr, since other dimension of input is voffset + self.declare_arg("v_wei_os") + self.declare_arg("s_c") + self.declare_arg("s_gemm_k_num_c") self.declare_arg("v_in_os") self.declare_arg("v_in_ihi_list") self.declare_arg("v_in_iwi_list") @@ -283,6 +287,9 @@ def expr(self): self._emit(f"s_cmp_eq_u32 1, s[{self.s_flag_need_acc_yx()}]") self._emit(f"s_cbranch_scc0 {label_acc_yx_end} ; no need do accumulate yx") self._emit_front(f"{label_acc_yx}:") + # wei os need to add a whole c when yx is changing + self._emit(f"v_add_u32 v[{self.v_wei_os()}], v[{self.v_wei_os()}], s[{self.s_c()}]") + self._emit(f"v_sub_u32 v[{self.v_wei_os()}], v[{self.v_wei_os()}], s[{self.s_gemm_k_num_c()}]") if self.tunable.tensor_a_pass_through: self._emit(f"s_sub_u32 s[{self.s_in_base()}], s[{self.s_in_base()}], s[{self.s_gemm_k_num_c()}]") self._emit(f"s_subb_u32 s[{self.s_in_base(1)}], s[{self.s_in_base(1)}], 0") @@ -444,7 +451,8 @@ def __init__(self, mc, outer): self.k_magic_2 = sym_t('k_magic_2' ,96) self.k_magic_3 = sym_t('k_magic_3' ,100) self.k_shift_pack_0 = sym_t('k_shift_pack_0' ,104) - self.k__pack_0 = sym_t('k__pack_0' ,108) + self.k_gemm_k_global_split = sym_t("k_gemm_k_global_split", 108) + #self.k__pack_0 = sym_t('k__pack_0' ,108) self.k_end = sym_t('k_end' ,112) else: self.k_end = sym_t('k_end' ,88) @@ -498,6 +506,7 @@ def __init__(self, mc, outer): self.s_out_stride_wo = sym_t('s_out_stride_wo' , sseq(1)) self.s_out_stride_n = sym_t('s_out_stride_n' , sseq(1)) + self.s_block_gtc_ic = sym_t("s_block_gtc_ic" , sseq(1)) # add c split self.s_block_gtc_ig = sym_t("s_block_gtc_ig" , sseq(1)) self.s_block_gtc_ik = sym_t("s_block_gtc_ik" , sseq(1)) self.s_block_gtc_inb = sym_t("s_block_gtc_inb" , sseq(1)) @@ -542,7 +551,9 @@ def __init__(self, mc, outer): self.s_magic_2 = sym_t("s_magic_2" 
,self.s_p_out.value + 2) self.s_magic_3 = sym_t("s_magic_3" ,self.s_p_out.value + 3) self.s_shift_pack_0 = sym_t("s_shift_pack_0" ,self.s_flag_need_acc_yx.value) - + + self.s_gemmk_split = sym_t("s_gemmk_split" ,sseq(1)) + self.s_sub_c = sym_t("s_sub_c" ,sseq(1)) self.s_tmp = sym_t("s_tmp" ,sseq(6, 2)) self.s_end = sym_t("s_end" ,sseq()) @@ -1083,6 +1094,7 @@ def emit_kernel_prologue(self): self._emit(f"s_load_dwordx2 s[{s.s_magic_0((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_0()}") self._emit(f"s_load_dwordx2 s[{s.s_magic_2((0, 1))}], s[{s.s_ka((0, 1))}], 0+{k.k_magic_2()}") self._emit(f"s_load_dword s[{s.s_shift_pack_0()}], s[{s.s_ka((0, 1))}], 0+{k.k_shift_pack_0()}") + self._emit(f"s_load_dword s[{s.s_gemmk_split()}], s[{s.s_ka((0, 1))}], 0+{k.k_gemm_k_global_split()}") self._emit(f"; in(e, c, nb0, nb1) thread_lengths: {ta_e}x{ta_c}x{ta_nb0}x{ta_nb1}, cluster_length: {ca_e}x{ca_c}x{ca_nb0}x{ca_nb1}, k_pack:{k_pack}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") @@ -1114,6 +1126,7 @@ def emit_kernel_prologue(self): self._emit(f"; calculate index") # calculate stride, not shift data byte yet # input + self._emit(f"s_lshr_b32 s[{s.s_sub_c()}], s[{s.s_c()}], s[{s.s_gemmk_split()}] ;add gkgs for c") self._emit(f"s_mul_i32 s[{s.s_in_stride_wi()}], s[{s.s_c()}], s[{s.s_group()}]") self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_wi()}], s[{s.s_in_stride_wi()}]") self._emit(f"s_mul_i32 s[{s.s_in_stride_n()}], s[{s.s_hi()}], s[{s.s_tmp(2)}]") @@ -1150,7 +1163,7 @@ def emit_kernel_prologue(self): self._emit(f"s_addc_u32 s[{s.s_p_out(1)}], s[{s.s_p_out(1)}], s[{s.s_tmp(1)}]") # early init s_knum in case shifted - self._emit(f"s_mov_b32 s[{s.s_knum()}], s[{s.s_wei_stride_k()}]") + self._emit(f"s_lshr_b32 s[{s.s_knum()}], s[{s.s_wei_stride_k()}], s[{s.s_gemmk_split()}]") # pad gemm_m, gemm_n if self.tunable.nxe != 0: @@ -1174,6 +1187,12 @@ def emit_kernel_prologue(self): self._emit(f"s_lshr_b32 s[{s.s_tmp()}], s[{s.s_dim_mp()}], {igemm_log2(self.tunable.gemm_m_per_block)}") self._emit(f"s_lshr_b32 s[{s.s_tmp(1)}], s[{s.s_dim_np()}], {igemm_log2(self.tunable.gemm_n_per_block)}") self._emit(f"s_mul_i32 s[0], s[{s.s_tmp(1)}], s[{s.s_tmp()}]") + # calculate block ic + self._emit(f"s_lshl_b32 s[{s.s_tmp(3)}], 1, s[{s.s_gemmk_split()}]") + self._emit(f"s_sub_u32 s[{s.s_tmp(3)}], s[{s.s_tmp(3)}], 1") + self._emit(f"s_and_b32 s[{s.s_block_gtc_ic()}], s[{s.s_bx()}], s[{s.s_tmp(3)}]") + self._emit(f"s_lshr_b32 s[{s.s_bx()}], s[{s.s_bx()}], s[{s.s_gemmk_split()}]") + self._emit(f"s_mul_i32 s[{s.s_block_gtc_ic()}], s[{s.s_block_gtc_ic()}], s[{s.s_sub_c()}]") if IGEMM_GTC_FEAT_MAGIC_DIVISION: self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_0()}], 0x00080018 ; offset:24, width:8") self._emit(m_mdiv_u32_ss(s.s_tmp(4), s.s_block_gtc_ig(), s.s_bx(), s.s_magic_3(), s.s_tmp(3), '0', s.s_tmp())) @@ -1255,6 +1274,8 @@ def calculate_and_load_input(): # s_in_stride_wi need shift before! 
self._emit(self.try_shift_stride(s.s_in_stride_wi, igemm_log2(data_byte))) + self._emit(f"v_add_u32 v[{v.v_tmp(1)}], v[{v.v_tmp(1)}], s[{s.s_block_gtc_ic()}]") + self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic_a() if self.tunable.tensor_a_pass_through else v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(0)}]") self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(0)}], v[{v.v_tmp()}]") @@ -1310,6 +1331,7 @@ def calculate_and_load_input(): self._emit(m_int_div_rem_vs(v.v_in_iwi_list(i), v.v_in_ihi_list(i), v.v_tmp(4), s.s_wi(), v.v_tmp(), s.s_tmp())) self._emit(f"v_mul_lo_u32 v[{v.v_tmp(1)}], s[{s.s_in_stride_n()}], v[{v.v_in_in()}]") + self._emit(f"v_add_u32 v[{v.v_tmp(1)}], v[{v.v_tmp(1)}], s[{s.s_block_gtc_ic()}]") self._emit(f"v_add_lshl_u32 v[{v.v_tmp(4)}], v[{v.v_gtc_ic_a() if self.tunable.tensor_a_pass_through else v.v_gtc_ic()}], v[{v.v_tmp(1)}], {igemm_log2(data_byte)}") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wi()}], v[{v.v_in_ihi_list(i)}]") self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_in_iwi_list(i)}], v[{v.v_tmp()}]") @@ -1352,6 +1374,7 @@ def calculate_and_load_weight(): self._emit(f"v_add_u32 v[{v.v_tmp(5)}], s[{s.s_block_gtc_ik()}], v[{v.v_wei_ik()}]") self._emit(f"v_mul_lo_u32 v[{v.v_tmp()}], s[{s.s_wei_stride_k()}], v[{v.v_tmp(5)}]") + self._emit(f"v_add_u32 v[{v.v_tmp()}], v[{v.v_tmp()}], s[{s.s_block_gtc_ic()}]") self._emit(f"v_add_lshl_u32 v[{v.v_wei_os()}], v[{v.v_tmp()}], v[{v.v_gtc_ic()}], {igemm_log2(data_byte)}") # wei flag @@ -1484,7 +1507,8 @@ def calculate_and_load_weight(): self._emit(f"v_add_u32 v[{v.v_out_os()}], v[{v.v_out_os()}], v[{v.v_tmp()}]") self._emit(f"; move slice stride") - self._emit(f"s_lshl_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_c()}], {igemm_log2(data_byte)}") + self._emit(f"s_lshl_b32 s[{s.s_gemm_k_num_c()}], s[{s.s_sub_c()}], {igemm_log2(data_byte)}") + self._emit(f"s_lshl_b32 s[{s.s_c()}], s[{s.s_c()}], {igemm_log2(data_byte)}") w_flag_cnt = 0 self._emit(f"v_bfe_u32 v[{v.v_wei_flag(0)}], v[{v.v_wei_tmp_pack()}], {0}, 1") @@ -1584,7 +1608,10 @@ def move_slice_window_acc(): s.s_tmp())) else: self._emit(m_move_slice_window_accumulate( - *(s.s_p_in(), s.s_in_c_itr(), s.s_gemm_k_num_c()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),), + *(s.s_p_in(), s.s_in_c_itr()) if self.tunable.tensor_a_pass_through else (s.s_in_offset(),), + v.v_wei_os(), + s.s_c(), + s.s_gemm_k_num_c(), v.v_in_os(), v.v_in_ihi_list(), v.v_in_iwi_list(), diff --git a/retieve_perf_data.py b/retieve_perf_data.py new file mode 100644 index 00000000..3b301470 --- /dev/null +++ b/retieve_perf_data.py @@ -0,0 +1,77 @@ +import re + +def get_log_lines(log_file_name): + with open(log_file_name, 'r') as f: + log_lines = f.readlines() + return log_lines + +def get_runtime(log_file_name): + txt_lines = get_log_lines(log_file_name) + min_cost = [] + min_kernel = [] + sel_kernel = [] + sel_costs = [] + + for each_line in txt_lines: + res_str = re.search(r'(?<=fastest:).*', each_line) + if res_str: + res_str = re.search(r'(?<=tflops:)\d+\.?\d*', each_line) + if res_str: + driver_cost = float(res_str.group()) + print(f"driver_cost={driver_cost}") + min_cost.append(driver_cost) + res_str = re.search(r'kernel_name:', each_line) + if res_str: + min_kernel.append(each_line.split(":")[-1][:-1]) + print(each_line.split(":")[-1][:-1]) + res_str = re.search(r'selected kernel:', each_line) + if res_str: + sel_kernel.append(each_line.split(":")[-1][:-1]) + 
print(each_line.split(":")[-1][:-1]) + + res_str = re.search(r'(?<=selected cost:)\d+\.?\d*', each_line) + if res_str: + sel_cost = float(res_str.group()) + print(f"driver_cost={sel_cost}") + sel_costs.append(sel_cost) + + with open("./wrw_model.csv", "w") as f: + for num_cost in min_cost: + f.write(f"{num_cost}") + f.write("\n") + for sel_cost in sel_costs: + f.write(f"{sel_cost}") + f.write("\n") + for kernel_name in min_kernel: + f.write(f"{kernel_name}") + f.write("\n") + for s_kernel_name in sel_kernel: + f.write(f"{s_kernel_name}") + f.write("\n") + +def get_kernel_name(log_file_name): + txt_lines = get_log_lines(log_file_name) + section_lines = [] + store_line = 0 + kernel_names = [] + for each_line in txt_lines: + conv_line = re.match(r"(?!#)./out/conv_driver.exe conv", each_line) + if store_line: + kernel_name = re.match(r"kernel:", each_line) + if kernel_name: + name = each_line.split(':')[-1] + print(f"\t\"{name[:-1]}\",") + kernel_names.append(name) + if conv_line: + store_line = 1 + conv_end_line = re.match(r"min cost:", each_line) + if conv_end_line: + break + + return kernel_names + + +if __name__ == '__main__': + #names = get_kernel_name("./wrw_model.log") + #print(len(names)) + get_runtime("./fwd_fp32_nhwc_models.log") diff --git a/script/gtc_conv_model.sh b/script/gtc_conv_model.sh index b2a7ffba..95e272f8 100755 --- a/script/gtc_conv_model.sh +++ b/script/gtc_conv_model.sh @@ -1,14 +1,33 @@ #!/bin/sh -if [ $# -ne 1 ] +if [ $# -lt 1 ] then - echo "please give this script a direction" - echo "now I use bwd as default" - DIR=bwd + DIR=bwd else DIR=$1 fi -export IGEMM_HSACO=out/igemm_${DIR}_gtc_gfx908.hsaco + +if [ $# -eq 2 ] +then + LAYOUT=$2 +else + LAYOUT="nchw" +fi + +if [ "${LAYOUT}" = "nchw" ] +then + LAYOUT_HSACO="" + LAYOUT_ARG="" +elif [ "${LAYOUT}" = "nhwc" ] +then + LAYOUT_HSACO="_nhwc" + LAYOUT_ARG="--in_layout NHWC --fil_layout NHWC --out_layout NHWC" +else + echo "wrong layout: ${LAYOUT}" + exit 1 +fi +echo IGEMM_HSACO=out/igemm_${DIR}_gtc_gfx908${LAYOUT_HSACO}.hsaco +export IGEMM_HSACO=out/igemm_${DIR}_gtc_gfx908${LAYOUT_HSACO}.hsaco export IGEMM_GPU_NAIVE_CONV_HSACO=out/naive_conv.hsaco export IGEMM_SCLK_MHZ=1283 export IGEMM_LOG_FASTEST_CONFIG=1 @@ -29,7 +48,8 @@ else fi # only forward support gemm_k_padding -if [ $FORW = 1 ] +#if [ $FORW = 1 ] +if [ 0 = 1 ] then ./out/conv_driver.exe conv -n 64 -c 3 -H 224 -W 224 -k 64 -y 7 -x 7 -p 3 -q 3 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ./out/conv_driver.exe conv -n 128 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW @@ -44,172 +64,171 @@ then ./out/conv_driver.exe conv -n 64 -c 512 -H 14 -W 14 -k 512 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 32 -F $FORW ./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 512 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 32 -F $FORW #exit 1 - fi #resnext101 -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv 
-n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -# ./out/conv_driver.exe conv -n 64 -c 3 -H 224 -W 224 -k 64 -y 7 -x 7 -p 3 -q 3 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 64 -c 3 -H 224 -W 224 -k 64 -y 7 -x 7 -p 3 -q 3 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW ${LAYOUT_ARG} #inception4 batch_size=128 -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW 
-./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 32 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -# ./out/conv_driver.exe conv -n 128 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 448 -H 8 -W 8 -k 384 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 48 -H 35 -W 35 -k 64 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 64 -H 73 -W 73 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 160 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 80 -H 73 -W 73 -k 192 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW +./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 128 -H 17 
-W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 160 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 160 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 192 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 32 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 192 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 320 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 448 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 256 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 48 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 128 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} 
+./out/conv_driver.exe conv -n 128 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 384 -H 8 -W 8 -k 384 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 448 -H 8 -W 8 -k 384 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 48 -H 35 -W 35 -k 64 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 64 -H 73 -W 73 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 160 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 80 -H 73 -W 73 -k 192 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 128 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #inception3 batch_size=64 -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 160 -H 73 -W 73 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 35 -W 35 -k 224 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 192 -H 71 -W 71 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 256 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 224 -H 35 -W 35 -k 256 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 256 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 320 -y 7 -x 1 -p 3 -q 0 
-u 1 -v 1 -l 1 -j 1 -F $FORW -# ./out/conv_driver.exe conv -n 64 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 320 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 96 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 448 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 448 -H 8 -W 8 -k 512 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 147 -W 147 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 96 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 17 -W 17 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 160 -H 73 -W 73 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 17 -W 17 -k 224 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} 
+./out/conv_driver.exe conv -n 64 -c 192 -H 35 -W 35 -k 224 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 192 -H 71 -W 71 -k 192 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 224 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 224 -H 17 -W 17 -k 256 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 224 -H 35 -W 35 -k 256 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 256 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 17 -W 17 -k 320 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 64 -c 3 -H 299 -W 299 -k 32 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 32 -H 147 -W 147 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 32 -H 149 -W 149 -k 32 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 320 -H 17 -W 17 -k 320 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 192 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 35 -W 35 -k 96 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 384 -H 8 -W 8 -k 448 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 448 -H 8 -W 8 -k 512 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 1 -x 3 -p 0 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 8 -W 8 -k 256 -y 3 -x 1 -p 1 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 147 -W 147 -k 96 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 64 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 73 -W 73 -k 96 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 96 -H 35 -W 35 -k 96 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #resnet50 -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe 
conv -n 64 -c 1024 -H 14 -W 14 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 128 -H 28 -W 28 -k 128 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 128 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 128 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -# ./out/conv_driver.exe conv -n 64 -c 3 -H 230 -W 230 -k 64 -y 7 -x 7 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 256 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 512 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 1024 -H 14 -W 14 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 128 -H 28 -W 28 -k 128 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 128 -H 28 -W 28 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 2048 -H 7 -W 7 -k 512 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 128 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 256 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +# ./out/conv_driver.exe conv -n 64 -c 3 -H 230 -W 230 -k 64 -y 7 -x 7 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 1024 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} 
+./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 28 -W 28 -k 256 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 512 -H 7 -W 7 -k 512 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #from v4r1_origin_conv.sh -./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1024 -H 17 -W 17 -k 1024 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 256 -H 34 -W 34 -k 256 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 35 -W 35 -k 128 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 512 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 256 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 48 -H 7 -W 7 -k 128 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW -./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW +#./out/conv_driver.exe conv -n 64 -c 64 -H 56 -W 56 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 1024 -H 17 -W 17 -k 1024 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 64 -c 256 -H 34 -W 34 -k 256 -y 3 -x 3 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 128 -H 35 -W 35 -k 128 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW 
${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 2048 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 1280 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 512 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 64 -c 1536 -H 8 -W 8 -k 384 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 256 -H 28 -W 28 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 768 -H 17 -W 17 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 528 -H 14 -W 14 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 832 -H 7 -W 7 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 288 -H 35 -W 35 -k 384 -y 3 -x 3 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 48 -H 7 -W 7 -k 128 -y 5 -x 5 -p 2 -q 2 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 1 -x 7 -p 0 -q 3 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} +#./out/conv_driver.exe conv -n 128 -c 128 -H 17 -W 17 -k 128 -y 7 -x 1 -p 3 -q 0 -u 1 -v 1 -l 1 -j 1 -F $FORW ${LAYOUT_ARG} #mask rcnn -./out/conv_driver.exe conv -n 2 -c 256 -H 12 -W 18 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 1024 -H 34 -W 84 -k 256 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 1024 -H 40 -W 52 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 256 -H 100 -W 104 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 256 -H 10 -W 20 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 64 -H 71 -W 83 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 2 -c 64 -H 59 -W 57 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 4 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 4 -c 256 -H 28 -W 28 -k 256 -y 2 -x 2 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 3 -c 256 -H 28 -W 28 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 1 -c 256 -H 32 -W 64 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 1 -c 64 -H 17 -W 17 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 256 -H 12 -W 18 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 1024 -H 34 -W 84 -k 256 -y 1 -x 1 -p 0 -q 
0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 1024 -H 40 -W 52 -k 512 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 256 -H 100 -W 104 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 256 -H 10 -W 20 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 64 -H 71 -W 83 -k 128 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 2 -c 64 -H 59 -W 57 -k 12 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 4 -c 256 -H 14 -W 14 -k 256 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 4 -c 256 -H 28 -W 28 -k 256 -y 2 -x 2 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 3 -c 256 -H 28 -W 28 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 1 -c 256 -H 32 -W 64 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 1 -c 64 -H 17 -W 17 -k 80 -y 1 -x 1 -p 0 -q 0 -u 1 -v 1 -l 1 -j 1 -g 1 -F $FORW #retina net bs=16 -./out/conv_driver.exe conv -n 16 -c 256 -H 12 -W 12 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 16 -c 256 -H 134 -W 77 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW -./out/conv_driver.exe conv -n 16 -c 256 -H 71 -W 101 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 16 -c 256 -H 12 -W 12 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 16 -c 256 -H 134 -W 77 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW +#./out/conv_driver.exe conv -n 16 -c 256 -H 71 -W 101 -k 256 -y 3 -x 3 -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -g 1 -F $FORW From cf45e48489388be07c6f3325a393220781450cc4 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 22 Mar 2021 22:41:25 +0800 Subject: [PATCH 37/40] fix a bug in find max gks --- driver/igemm_fwd_gtc_driver.h | 36 ++++++++++++++++++++++++++++++++++- driver/igemm_gtc_base.h | 8 +++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/driver/igemm_fwd_gtc_driver.h b/driver/igemm_fwd_gtc_driver.h index 99f1758c..514a0c8c 100755 --- a/driver/igemm_fwd_gtc_driver.h +++ b/driver/igemm_fwd_gtc_driver.h @@ -141,6 +141,37 @@ static void dump_fwd_karg(igemm_fwd_gtc_karg_t * karg){ std::cout<<std::endl; } +static void dump_fwd_karg(igemm_fwd_gtc_nhwc_karg_t * karg){ + std::cout<<"p_in:" <<karg->p_in<<","; + std::cout<<"p_wei:" <<karg->p_wei<<","; + std::cout<<"p_out:" <<karg->p_out<<","; + std::cout<<"hi:" <<karg->hi<<","; + std::cout<<"wi:" <<karg->wi<<","; + std::cout<<"n:" <<karg->n<<","; + std::cout<<"k:" <<karg->k<<","; + std::cout<<"c:" <<karg->c<<","; + std::cout<<"ho:" <<karg->ho<<","; + std::cout<<"wo:" <<karg->wo<<","; + std::cout<<"stride_h:" <<karg->stride_h<<","; + std::cout<<"stride_w:" <<karg->stride_w<<","; + std::cout<<"dilation_h:" <<karg->dilation_h<<","; + std::cout<<"dilation_w:" <<karg->dilation_w<<","; + std::cout<<"pad_h:" <<karg->pad_h<<","; + std::cout<<"pad_w:" <<karg->pad_w<<","; + std::cout<<"y:" <<karg->y<<","; + std::cout<<"x:" <<karg->x<<","; + std::cout<<"group:" <<karg->group<<","; +#if USE_MAGIC_DIV + std::cout<<"magic_0:" <<karg->magic_0<<","; + std::cout<<"magic_1:" <<karg->magic_1<<","; + std::cout<<"magic_2:" <<karg->magic_2<<","; + std::cout<<"magic_3:" <<karg->magic_3<<","; + std::cout<<"shift_pack_0:" <<karg->shift_pack_0<<","; +#endif + std::cout<<"ks:" <<karg->ks; + std::cout<<std::endl; +} int max_split_num = tunable->gemm_k_global_split == 0 ?
- 0 : igemm_get_max_gks(c, tunable->gemm_k_per_block, MAX_GEMM_K_SPLITS); + 0 : igemm_get_max_gks(c / group, tunable->gemm_k_per_block, MAX_GEMM_K_SPLITS); for(int gks = 0; gks <= max_split_num; gks++){ size_t grid_size = get_grid_size(arg, tunable) * (1 << gks); if(tunable->tensor_layout == "nhwc"){ // This is hacky, but in MIOpen we prefer a heuristic way to set gks, so ok now. igemm_fwd_gtc_nhwc_karg_t *karg_revalue = (igemm_fwd_gtc_nhwc_karg_t *)(karg_buffer); karg_revalue->ks = gks; + // dump_fwd_karg(karg_revalue); + // printf("block:%d, grid:%d\n", block_size, grid_size); + // fflush(stdout); } float duration = igemm_launch_kernels_with_epilog({ {kernel_func, karg_buffer, karg_size, {grid_size * block_size, splits, 1}, {block_size, 1, 1}} diff --git a/driver/igemm_gtc_base.h b/driver/igemm_gtc_base.h index 0409ec8c..00cdf1ef 100755 --- a/driver/igemm_gtc_base.h +++ b/driver/igemm_gtc_base.h @@ -454,7 +454,13 @@ static inline float igemm_launch_kernels_with_epilog(const std::vector max_log2_splits) gks = max_log2_splits; return gks; From 2cba95cfd315287228b4e5ed32c2e0aedcac78e9 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 22 Mar 2021 23:26:42 +0800 Subject: [PATCH 38/40] add missing configs --- config/igemm_fwd_gtc_gfx908_nhwc.config | 126 ++++++++---------------- 1 file changed, 42 insertions(+), 84 deletions(-) diff --git a/config/igemm_fwd_gtc_gfx908_nhwc.config b/config/igemm_fwd_gtc_gfx908_nhwc.config index 41e3e81e..bfb639e1 100644 --- a/config/igemm_fwd_gtc_gfx908_nhwc.config +++ b/config/igemm_fwd_gtc_gfx908_nhwc.config @@ -136,6 +136,7 @@ tensor_layout = 'nhwc' nxb = 0 nxe = 0 + #--------------------------- 128x256 [igemm_fwd_gtc] gemm_m_per_block = 128 @@ -156,7 +157,7 @@ direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' nxb = 0 -nxe = 1 +nxe = 0 #--------------------------- 128x256 [igemm_fwd_gtc] @@ -178,7 +179,8 @@ direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' nxb = 0 -nxe = 0 +nxe = 1 +gemm_k_global_split = 1 #--------------------------- 128x256 [igemm_fwd_gtc] @@ -200,98 +202,31 @@ direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' nxb = 0 -nxe = 1 +nxe = 0 gemm_k_global_split = 1 -#--------------------------- 128x256 + +#--------------------------- 128x128 [igemm_fwd_gtc] gemm_m_per_block = 128 -gemm_n_per_block = 256 +gemm_n_per_block = 128 gemm_k_per_block = 16 wave_tile_m = 32 wave_step_m = 1 wave_repeat_m = 2 -wave_tile_n = 64 +wave_tile_n = 32 wave_step_n = 1 wave_repeat_n = 2 -wave_tile_k = 1 +wave_tile_k = 2 tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 -tensor_b_thread_lengths = [1, 4, 4, 1] # ExCxK0xK1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 direction = "fwd" precision = "fp32" tensor_layout = 'nhwc' nxb = 0 nxe = 0 -gemm_k_global_split = 1 - -#--------------------------- 128x128 -#[igemm_fwd_gtc] -#gemm_m_per_block = 128 -#gemm_n_per_block = 128 -#gemm_k_per_block = 16 -#wave_tile_m = 32 -#wave_step_m = 1 -#wave_repeat_m = 2 -#wave_tile_n = 32 -#wave_step_n = 1 -#wave_repeat_n = 2 -#wave_tile_k = 2 -#tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 -#tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 -#tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 -#tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 -#direction = "fwd" -#precision = "fp32" -#tensor_layout = 'nhwc' -#nxb = 0 -#nxe = 0 - -#--------------------------- 128x128 -#[igemm_fwd_gtc] -#gemm_m_per_block = 128 
-#gemm_n_per_block = 128 -#gemm_k_per_block = 16 -#wave_tile_m = 32 -#wave_step_m = 1 -#wave_repeat_m = 2 -#wave_tile_n = 32 -#wave_step_n = 1 -#wave_repeat_n = 2 -#wave_tile_k = 2 -#tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 -#tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 -#tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 -#tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 -#direction = "fwd" -#precision = "fp32" -#tensor_layout = 'nhwc' -#nxb = 0 -#nxe = 1 - -#--------------------------- 128x128 -#[igemm_fwd_gtc] -#gemm_m_per_block = 128 -#gemm_n_per_block = 128 -#gemm_k_per_block = 16 -#wave_tile_m = 32 -#wave_step_m = 1 -#wave_repeat_m = 2 -#wave_tile_n = 32 -#wave_step_n = 1 -#wave_repeat_n = 2 -#wave_tile_k = 2 -#tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 -#tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 -#tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 -#tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 -#direction = "fwd" -#precision = "fp32" -#tensor_layout = 'nhwc' -#nxb = 0 -#nxe = 0 -#gemm_k_global_split = 1 #--------------------------- 128x128 [igemm_fwd_gtc] @@ -314,24 +249,22 @@ precision = "fp32" tensor_layout = 'nhwc' nxb = 0 nxe = 1 -gemm_k_global_split = 1 -#--------------------------- 128x64 +#--------------------------- 128x128 [igemm_fwd_gtc] gemm_m_per_block = 128 -gemm_n_per_block = 64 +gemm_n_per_block = 128 gemm_k_per_block = 16 wave_tile_m = 32 wave_step_m = 1 -wave_repeat_m = 1 +wave_repeat_m = 2 wave_tile_n = 32 wave_step_n = 1 wave_repeat_n = 2 wave_tile_k = 2 -tensor_a_pass_through = 1 -tensor_a_thread_lengths = [1, 8, 1, 1] # ExCxNB0xNB1 -tensor_a_cluster_lengths = [1, 2, 4, 32] # ExCxNB0xNB1 -tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_a_thread_lengths = [1, 4, 2, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 4, 1, 64] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 2, 1] # ExCxK0xK1 tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 direction = "fwd" precision = "fp32" @@ -340,6 +273,7 @@ nxb = 0 nxe = 1 gemm_k_global_split = 1 + #--------------------------- 64x128 [igemm_fwd_gtc] gemm_m_per_block = 64 @@ -429,6 +363,30 @@ tensor_layout = 'nhwc' nxb = 0 nxe = 1 +#--------------------------- 128x64 +[igemm_fwd_gtc] +gemm_m_per_block = 128 +gemm_n_per_block = 64 +gemm_k_per_block = 16 +wave_tile_m = 32 +wave_step_m = 1 +wave_repeat_m = 1 +wave_tile_n = 32 +wave_step_n = 1 +wave_repeat_n = 2 +wave_tile_k = 2 +tensor_a_pass_through = 1 +tensor_a_thread_lengths = [1, 8, 1, 1] # ExCxNB0xNB1 +tensor_a_cluster_lengths = [1, 2, 4, 32] # ExCxNB0xNB1 +tensor_b_thread_lengths = [1, 4, 1, 1] # ExCxK0xK1 +tensor_b_cluster_lengths = [1, 4, 1, 64] # ExCxK0XK1 +direction = "fwd" +precision = "fp32" +tensor_layout = 'nhwc' +nxb = 0 +nxe = 1 +gemm_k_global_split = 1 + # #--------------------------- 128x64 # [igemm_fwd_gtc] # gemm_m_per_block = 128 From 2dc86a85ec0fc2adf2441ffe4a9c0ce40e8c2f56 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 22 Mar 2021 23:19:13 +0800 Subject: [PATCH 39/40] Nhwc inference test gfx1030 (#90) * 256x8 * 512x8 * optimize double global prefetch * 512x8x8 * 1024x8 * 512x4 * 256x4 * 128x4, 384x4 * 256x8x8 * int8x4 * 512x8 int8 * update 256x8x16, 512x16x8 * 1024x16x8 * 256x4 --- test/inference/build.sh | 3 +- test/inference/igemm_fwd_btm_nhwc.h | 86 +- .../kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm | 493 +++++ .../fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm | 1266 ++++++++++++ .../fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm | 666 +++++++ .../igemm_fwd_btm_nhwc_fp16_128x016.asm | 32 +- 
.../fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm | 652 ++++++ .../fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm | 1520 ++++++++++++++ .../igemm_fwd_btm_nhwc_fp16_256x016.asm | 28 +- .../fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm | 779 ++++++++ .../fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm | 886 +++++++++ .../fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm | 1756 +++++++++++++++++ .../kernel/igemm_fwd_btm_nhwc_fp16.asm | 153 -- .../kernel/int8/igemm_fwd_btm_nhwc_int8.asm | 424 ++++ .../int8/igemm_fwd_btm_nhwc_int8_1024x016.asm | 1081 ++++++++++ .../int8/igemm_fwd_btm_nhwc_int8_256x004.asm | 738 +++++++ .../int8/igemm_fwd_btm_nhwc_int8_256x008.asm | 585 ++++++ .../int8/igemm_fwd_btm_nhwc_int8_512x008.asm | 757 +++++++ .../int8/igemm_fwd_btm_nhwc_int8_512x016.asm | 1545 +++++++++++++++ test/inference/test_inference.cpp | 193 +- 20 files changed, 13395 insertions(+), 248 deletions(-) create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm rename test/inference/kernel/{ => fp16}/igemm_fwd_btm_nhwc_fp16_128x016.asm (98%) create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm rename test/inference/kernel/{ => fp16}/igemm_fwd_btm_nhwc_fp16_256x016.asm (98%) create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm create mode 100644 test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm delete mode 100644 test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm create mode 100644 test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm create mode 100644 test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_1024x016.asm create mode 100644 test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x004.asm create mode 100644 test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x008.asm create mode 100644 test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x008.asm create mode 100644 test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x016.asm diff --git a/test/inference/build.sh b/test/inference/build.sh index 7506c183..11fa1612 100644 --- a/test/inference/build.sh +++ b/test/inference/build.sh @@ -5,7 +5,8 @@ rm -rf out mkdir out /opt/rocm/hip/bin/hipcc --amdgpu-target=$ARCH -Idriver -std=c++14 -lpthread test/inference/test_inference.cpp -o out/test_inference.exe || exit 1 -/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/ test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm -o out/igemm_fwd_btm_nhwc_fp16.hsaco || exit 1 +/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/fp16/ test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm -o out/igemm_fwd_btm_nhwc_fp16.hsaco || exit 1 +/opt/rocm/llvm/bin/clang++ -x assembler -target amdgcn--amdhsa -mcpu=$ARCH -mcumode -Itest/inference/kernel/int8/ test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm -o out/igemm_fwd_btm_nhwc_int8.hsaco || exit 1 /opt/rocm/hip/bin/hipcc -x hip --cuda-gpu-arch=$ARCH --cuda-device-only -c -O3 driver/gpu_naive_conv/naive_conv.cpp -o out/naive_conv.hsaco diff --git a/test/inference/igemm_fwd_btm_nhwc.h b/test/inference/igemm_fwd_btm_nhwc.h index aa217e0a..6707f476 100644 --- a/test/inference/igemm_fwd_btm_nhwc.h +++ 
b/test/inference/igemm_fwd_btm_nhwc.h
@@ -77,8 +77,8 @@ typedef struct {
 uint32_t stride_m;
 uint32_t magic_0;
 uint32_t magic_1;
+ uint32_t magic_2;
 uint32_t shift_pack_0;
- uint32_t __pack_0;
 } __attribute__((packed)) igemm_fwd_btm_2d_karg_t;
 static inline void dump_igemm_fwd_btm_2d_karg(igemm_fwd_btm_2d_karg_t * karg)
 {
@@ -105,11 +105,13 @@ static inline void dump_igemm_fwd_btm_2d_karg(igemm_fwd_btm_2d_karg_t * karg)
 std::cout<<"stride_m:"<<karg->stride_m<<", ";
 std::cout<<"magic_0:"<<karg->magic_0<<", ";
 std::cout<<"magic_1:"<<karg->magic_1<<", ";
+ std::cout<<"magic_2:"<<karg->magic_2<<", ";
 std::cout<<"shift_pack_0:"<<karg->shift_pack_0<kernel_name; }
+ bool is_valid(const args_t *arg, igemm_fwd_btm_kernel_info_t * kernel_info)
+ {
+ size_t hi = arg->get_int("in_h");
+ size_t wi = arg->get_int("in_w");
+ size_t n = arg->get_int("batchsize");
+ size_t k = arg->get_int("out_channels");
+ size_t c = arg->get_int("in_channels");
+
+ size_t sy = arg->get_int("conv_stride_h");
+ size_t sx = arg->get_int("conv_stride_w");
+ size_t dy = arg->get_int("dilation_h");
+ size_t dx = arg->get_int("dilation_w");
+ size_t py = arg->get_int("pad_h");
+ size_t px = arg->get_int("pad_w");
+ size_t fy = arg->get_int("fil_h");
+ size_t fx = arg->get_int("fil_w");
+ size_t ho = gpu_conv_out_size(hi, py, dy, fy, sy);
+ size_t wo = gpu_conv_out_size(wi, px, dx, fx, sx);
+ size_t group = arg->get_int("group_count");
+
+ assert(c % group == 0 && k % group == 0);
+
+ assert(group != 0 && c % group == 0 && k % group == 0);
+
+ size_t k_per_group = k / group;
+ size_t c_per_group = c / group;
+
+ if(c_per_group != kernel_info->k_per_block)
+ return false;
+
+ if(k_per_group % kernel_info->n_per_block != 0)
+ return false;
+
+ return true;
+ }
+
 result_t run(const args_t *arg, hipModule_t module, igemm_fwd_btm_kernel_info_t * kernel_info, void *p_in, void *p_wei, void *p_out, int warmup, int repeat, const driverDataType_t& data_type) {
+ if(!is_valid(arg, kernel_info)){
+ result_t result;
+ result.return_code = -1;
+ return result;
+ }
 size_t hi = arg->get_int("in_h");
 size_t wi = arg->get_int("in_w");
 size_t n = arg->get_int("batchsize");
@@ -200,18 +259,29 @@ class igemm_fwd_btm_t {
 HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_info->kernel_name.c_str()));
 int block_size = kernel_info->block_size;
- int grid_size = kernel_info->occupancy * num_cu;
+ int num_gemm_m = (ho * wo + kernel_info->m_per_block - 1) / kernel_info->m_per_block;
+ int num_gemm_n = (k_per_group + kernel_info->n_per_block - 1) / kernel_info->n_per_block;
+
+ int grid_size = kernel_info->occupancy * num_cu;
 grid_size = env_get_int("GRID_SIZE", grid_size);
- int b_grids = (ho * wo + kernel_info->m_per_block - 1) / kernel_info->m_per_block;
+ if(grid_size % num_gemm_n == 0){
+ int grids_for_m = grid_size / num_gemm_n;
+ karg.batch_m = (num_gemm_m + grids_for_m - 1) / grids_for_m;
+ karg.stride_m = kernel_info->m_per_block * grids_for_m;
- karg.batch_m = (b_grids + grid_size - 1) / grid_size;
- karg.stride_m = kernel_info->m_per_block * grid_size;
+ }else{
+ grid_size = num_gemm_m * num_gemm_n;
+ karg.batch_m = 1;
+ karg.stride_m = 0;
+ }
 magic_div_u32_t mdiv_0 = magic_div_u32_gen(fx);
 magic_div_u32_t mdiv_1 = magic_div_u32_gen(wo);
+ magic_div_u32_t mdiv_2 = magic_div_u32_gen(num_gemm_n);
 karg.magic_0 = mdiv_0.magic;
 karg.magic_1 = mdiv_1.magic;
- karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, 0, 0);
+ karg.magic_2 = mdiv_2.magic;
+ karg.shift_pack_0 = magic_div_u32_pack_shift(mdiv_0.shift, mdiv_1.shift, mdiv_2.shift, 0);
 // printf("launch
fwd block:%d, grid:%d\n", block_size, grid_size); // dump_igemm_fwd_btm_2d_karg(&karg); diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm new file mode 100644 index 00000000..ff8d2218 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16.asm @@ -0,0 +1,493 @@ +; pay attention to register bank of v_c, v_b +.macro .fma_1x16_fp16 v_c, v_a, v_b + v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] + v_dot2c_f32_f16 v[\v_c+8 ], v[\v_a], v[\v_b+8 ] + v_dot2c_f32_f16 v[\v_c+9 ], v[\v_a], v[\v_b+9 ] + v_dot2c_f32_f16 v[\v_c+10], v[\v_a], v[\v_b+10] + v_dot2c_f32_f16 v[\v_c+11], v[\v_a], v[\v_b+11] + v_dot2c_f32_f16 v[\v_c+12], v[\v_a], v[\v_b+12] + v_dot2c_f32_f16 v[\v_c+13], v[\v_a], v[\v_b+13] + v_dot2c_f32_f16 v[\v_c+14], v[\v_a], v[\v_b+14] + v_dot2c_f32_f16 v[\v_c+15], v[\v_a], v[\v_b+15] +.endm + +.macro .fma_1x8_fp16 v_c, v_a, v_b + v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] +.endm + +.macro .fma_1x4_fp16 v_c, v_a, v_b + v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] +.endm + +.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp + s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer] + s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer] + s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift] +.endm + +.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp + .mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp + s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot] + s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp] +.endm + +.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp + v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer] + v_add_nc_u32 v[\v_tmp], v[\v_tmp], v[\v_numer] + v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp] +.endm + +.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp + .mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp + v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot] + v_sub_nc_u32 v[\v_rem], v[\v_numer], v[\v_tmp] +.endm + +.macro .v_clear_nc vid, num + _v = \vid + .rept \num + v_mov_b32 v[_v], 0 + _v = _v + 1 + .endr +.endm + +.include "igemm_fwd_btm_nhwc_fp16_128x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_128x016.asm" +.include "igemm_fwd_btm_nhwc_fp16_256x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_256x016.asm" +.include "igemm_fwd_btm_nhwc_fp16_256x008.asm" +.include "igemm_fwd_btm_nhwc_fp16_384x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_512x004.asm" +.include "igemm_fwd_btm_nhwc_fp16_512x008.asm" +.include "igemm_fwd_btm_nhwc_fp16_1024x008.asm" + +.amdgpu_metadata +--- +amdhsa.version: [ 1, 0 ] +amdhsa.kernels: + - .name: igemm_fwd_btm_nhwc_fp16_128x4x16_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_128x4x16_r2.kd + 
.sgpr_count: 64 + .vgpr_count: 88 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [64, 1, 1] + .max_flat_workgroup_size: 64 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_128x16x16_r3 + .symbol: igemm_fwd_btm_nhwc_fp16_128x16x16_r3.kd + .sgpr_count: 60 + .vgpr_count: 74 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, 
.value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_256x4x16_r1.kd + .sgpr_count: 60 + .vgpr_count: 112 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, 
.value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x16x16_r3 + .symbol: igemm_fwd_btm_nhwc_fp16_256x16x16_r3.kd + .sgpr_count: 60 + .vgpr_count: 112 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 13056 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x8x16_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_256x8x16_r2.kd + .sgpr_count: 64 + .vgpr_count: 128 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + 
.wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_256x8x8_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_256x8x8_r2.kd + .sgpr_count: 64 + .vgpr_count: 124 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [64, 1, 1] + .max_flat_workgroup_size: 64 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, 
.value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_384x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_384x4x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 114 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, 
.offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_512x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_512x4x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 140 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_512x8x16_r2 + .symbol: igemm_fwd_btm_nhwc_fp16_512x8x16_r2.kd + .sgpr_count: 64 + .vgpr_count: 188 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, 
.value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_512x8x8_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_512x8x8_r1.kd + .sgpr_count: 64 + .vgpr_count: 124 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 
52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_fp16_1024x8x8_r1 + .symbol: igemm_fwd_btm_nhwc_fp16_1024x8x8_r1.kd + .sgpr_count: 64 + .vgpr_count: 212 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 
4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} +... +.end_amdgpu_metadata diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm new file mode 100644 index 00000000..d7b69e86 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_1024x008.asm @@ -0,0 +1,1266 @@ +;---------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 64 +.set v_ax, 65 +.set v_ay, 97 +.set v_ib, 129 +.set v_b, 130 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 162 +.set v_in_ihi, 170 +.set v_in_iwi, 178 +.set v_in_flag, 186 +.set v_out_os, 194 +.set v_out_flag, 202 +.set v_tid, 210 +.set v_end, 212 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_1024x8x8_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_1024x8x8_r1,@function +igemm_fwd_btm_nhwc_fp16_1024x8x8_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 
v[v_wei_ie], 15, v0 ; yx + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 10 ; 1024 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ; s_lshl_b32 s[s_wei_offset], s[s_c], 4+1 ; 16x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ; v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + ;v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + 
v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+4] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+4], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+5], v[v_in_iwi+5], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+5], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + + v_sub_nc_i32 v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+6], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+7], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 
v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 v[v_in_os+7], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+4], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+4], 1, v[v_out_os+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+5], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+5], 1, v[v_out_os+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+6], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+6], 1, v[v_out_os+6] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 v[v_out_flag+6], 0, 1 + + v_mul_lo_u32 v[v_out_os+7], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+7], 1, v[v_out_os+7] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + 
.v_clear_nc v_c, 64 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], 
v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + + ;--- end move slice window + + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .v_clear_nc v_ay+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+24, 8 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ay+24:v_ay+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ay+28:v_ay+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ax +12, v_b + 0 + + .fma_1x8_fp16 v_c+32, v_ax +16, v_b + 0 + .fma_1x8_fp16 v_c+40, v_ax +20, v_b + 0 + .fma_1x8_fp16 v_c+48, v_ax +24, v_b + 0 + .fma_1x8_fp16 v_c+56, v_ax +28, v_b + 0 + + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ax +13, v_b + 8 + + .fma_1x8_fp16 v_c+32, v_ax +17, v_b + 8 + .fma_1x8_fp16 v_c+40, v_ax +21, v_b + 8 + .fma_1x8_fp16 v_c+48, v_ax +25, v_b + 8 + .fma_1x8_fp16 v_c+56, v_ax +29, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ax +14, v_b +16 + + .fma_1x8_fp16 v_c+32, v_ax +18, v_b +16 + .fma_1x8_fp16 v_c+40, v_ax +22, v_b +16 + .fma_1x8_fp16 v_c+48, v_ax +26, v_b +16 + .fma_1x8_fp16 v_c+56, v_ax +30, v_b +16 + + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ax +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ax +15, v_b +24 + + .fma_1x8_fp16 v_c+32, v_ax +19, v_b +24 + .fma_1x8_fp16 v_c+40, v_ax +23, v_b +24 + .fma_1x8_fp16 v_c+48, v_ax +27, v_b +24 + .fma_1x8_fp16 v_c+56, v_ax +31, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + 
ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 
1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax +0:v_ax +3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .v_clear_nc v_ax+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+24, 8 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_fp16 v_c+32, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+40, v_ay +20, v_b + 0 + .fma_1x8_fp16 v_c+48, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+56, v_ay +28, v_b + 0 + + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + .fma_1x8_fp16 v_c+32, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+40, v_ay +21, v_b + 8 + .fma_1x8_fp16 v_c+48, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+56, v_ay +29, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_fp16 v_c+32, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+40, v_ay +22, v_b +16 + .fma_1x8_fp16 v_c+48, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+56, v_ay +30, v_b +16 + + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + .fma_1x8_fp16 v_c+32, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+40, v_ay +23, v_b +24 + .fma_1x8_fp16 v_c+48, v_ay +27, v_b +24 + .fma_1x8_fp16 v_c+56, v_ay +31, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] 
+ v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + v_mov_b32 v[v_ay +24], v[v_ax +24] + v_mov_b32 v[v_ay +25], v[v_ax +25] + v_mov_b32 v[v_ay +26], v[v_ax +26] + v_mov_b32 v[v_ay +27], v[v_ax +27] + v_mov_b32 v[v_ay +28], v[v_ax +28] + v_mov_b32 v[v_ay +29], v[v_ax +29] + v_mov_b32 v[v_ay +30], v[v_ax +30] + v_mov_b32 v[v_ay +31], v[v_ax +31] + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], 
s[s_stride_w], v[v_in_iwi+2] + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_add_nc_u32 v[v_in_flag+4], s[s_ib_stride], v[v_in_flag+3] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_in_flag+4,s_magic_1,s_shift_m1,s_wo,v_in_os+4 + v_add_nc_u32 v[v_in_flag+5], s[s_ib_stride], v[v_in_flag+4] + .v_clear_nc v_ax+16, 4 + v_mul_u32_u24 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_in_flag+5,s_magic_1,s_shift_m1,s_wo,v_in_os+5 + v_add_nc_u32 v[v_in_flag+6], s[s_ib_stride], v[v_in_flag+5] + v_mul_u32_u24 v[v_in_os+4], s[s_wi], v[v_in_ihi+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_in_os+4], v[v_in_iwi+4], v[v_in_os+4] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_in_os+4] + + v_mul_u32_u24 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+5], v[v_in_iwi+5], s[s_pad_w] + + + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+5], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_in_os+5], v[v_in_iwi+5], v[v_in_os+5] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_in_os+5] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_in_flag+6,s_magic_1,s_shift_m1,s_wo,v_in_os+6 + v_add_nc_u32 v[v_in_flag+7], s[s_ib_stride], v[v_in_flag+6] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + v_sub_nc_i32 
v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_in_flag+7,s_magic_1,s_shift_m1,s_wo,v_in_os+7 + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+6], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_in_os+6], v[v_in_iwi+6], v[v_in_os+6] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_in_os+6] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+7], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_in_os+7], v[v_in_iwi+7], v[v_in_os+7] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 v[v_in_os+7], s[s_in_stride_wi], v[v_in_os+7] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_fp16 v_c+32, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+40, v_ay +20, v_b + 0 + .fma_1x8_fp16 v_c+48, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+56, v_ay +28, v_b + 0 + + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + .fma_1x8_fp16 v_c+32, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+40, v_ay +21, v_b + 8 + .fma_1x8_fp16 v_c+48, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+56, v_ay +29, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_fp16 v_c+32, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+40, v_ay +22, v_b +16 + .fma_1x8_fp16 v_c+48, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+56, v_ay +30, v_b +16 + + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + .fma_1x8_fp16 v_c+32, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+40, v_ay +23, v_b +24 + .fma_1x8_fp16 v_c+48, v_ay +27, v_b +24 + .fma_1x8_fp16 v_c+56, v_ay +31, v_b +24 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], 
v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] + + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + v_cvt_f16_f32 v[v_c +32], v[v_c +32] + v_cvt_f16_f32 v[v_c +33], v[v_c +33] + v_cvt_f16_f32 v[v_c +34], v[v_c +34] + v_cvt_f16_f32 v[v_c +35], v[v_c +35] + v_cvt_f16_f32 v[v_c +36], v[v_c +36] + v_cvt_f16_f32 v[v_c +37], v[v_c +37] + v_cvt_f16_f32 v[v_c +38], v[v_c +38] + v_cvt_f16_f32 v[v_c +39], v[v_c +39] + + v_cvt_f16_f32 v[v_c +40], v[v_c +40] + v_cvt_f16_f32 v[v_c +41], v[v_c +41] + v_cvt_f16_f32 v[v_c +42], v[v_c +42] + v_cvt_f16_f32 v[v_c +43], v[v_c +43] + v_cvt_f16_f32 v[v_c +44], v[v_c +44] + v_cvt_f16_f32 v[v_c +45], v[v_c +45] + v_cvt_f16_f32 v[v_c +46], v[v_c +46] + v_cvt_f16_f32 v[v_c +47], v[v_c +47] + + + v_pack_b32_f16 v[v_c_buf+16], v[v_c+32], v[v_c+33] + v_pack_b32_f16 v[v_c_buf+17], v[v_c+34], v[v_c+35] + v_pack_b32_f16 v[v_c_buf+18], v[v_c+36], v[v_c+37] + v_pack_b32_f16 v[v_c_buf+19], v[v_c+38], v[v_c+39] + + v_pack_b32_f16 v[v_c_buf+20], v[v_c+40], v[v_c+41] + v_pack_b32_f16 v[v_c_buf+21], v[v_c+42], v[v_c+43] + v_pack_b32_f16 v[v_c_buf+22], v[v_c+44], v[v_c+45] + v_pack_b32_f16 v[v_c_buf+23], v[v_c+46], v[v_c+47] + + v_cmpx_le_u32 1, v[v_out_flag+4] + global_store_dwordx4 v[v_out_os+4], v[v_c_buf+16:v_c_buf+19], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+5] + global_store_dwordx4 v[v_out_os+5], v[v_c_buf+20:v_c_buf+23], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +48], v[v_c +48] + v_cvt_f16_f32 v[v_c +49], v[v_c +49] + v_cvt_f16_f32 v[v_c +50], v[v_c +50] + v_cvt_f16_f32 v[v_c +51], v[v_c +51] + v_cvt_f16_f32 v[v_c +52], v[v_c +52] 
+ v_cvt_f16_f32 v[v_c +53], v[v_c +53] + v_cvt_f16_f32 v[v_c +54], v[v_c +54] + v_cvt_f16_f32 v[v_c +55], v[v_c +55] + + v_cvt_f16_f32 v[v_c +56], v[v_c +56] + v_cvt_f16_f32 v[v_c +57], v[v_c +57] + v_cvt_f16_f32 v[v_c +58], v[v_c +58] + v_cvt_f16_f32 v[v_c +59], v[v_c +59] + v_cvt_f16_f32 v[v_c +60], v[v_c +60] + v_cvt_f16_f32 v[v_c +61], v[v_c +61] + v_cvt_f16_f32 v[v_c +62], v[v_c +62] + v_cvt_f16_f32 v[v_c +63], v[v_c +63] + + + v_pack_b32_f16 v[v_c_buf+24], v[v_c+48], v[v_c+49] + v_pack_b32_f16 v[v_c_buf+25], v[v_c+50], v[v_c+51] + v_pack_b32_f16 v[v_c_buf+26], v[v_c+52], v[v_c+53] + v_pack_b32_f16 v[v_c_buf+27], v[v_c+54], v[v_c+55] + + v_pack_b32_f16 v[v_c_buf+28], v[v_c+56], v[v_c+57] + v_pack_b32_f16 v[v_c_buf+29], v[v_c+58], v[v_c+59] + v_pack_b32_f16 v[v_c_buf+30], v[v_c+60], v[v_c+61] + v_pack_b32_f16 v[v_c_buf+31], v[v_c+62], v[v_c+63] + + v_cmpx_le_u32 1, v[v_out_flag+6] + global_store_dwordx4 v[v_out_os+6], v[v_c_buf+24:v_c_buf+27], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+7] + global_store_dwordx4 v[v_out_os+7], v[v_c_buf+28:v_c_buf+31], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 64 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + v_add_nc_u32 v[v_out_os+4], s[s_out_stride], v[v_out_os+4] + v_add_nc_u32 v[v_out_os+5], s[s_out_stride], v[v_out_os+5] + v_add_nc_u32 v[v_out_os+6], s[s_out_stride], v[v_out_os+6] + v_add_nc_u32 v[v_out_os+7], s[s_out_stride], v[v_out_os+7] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 v[v_out_flag+6], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_1024x8x8_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_1024x8x8_r1 + .amdhsa_group_segment_fixed_size 2048 + 
.amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 212 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm new file mode 100644 index 00000000..8b7df108 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x004.asm @@ -0,0 +1,666 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 8 +.set v_ax, 9 +.set v_ay, 25 +.set v_ib, 41 +.set v_b, 42 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 74 +.set v_in_ihi, 76 +.set v_in_iwi, 78 +.set v_in_flag, 80 +.set v_out_os, 82 +.set v_out_flag, 84 +.set v_tid, 86 +.set v_end, 88 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_128x4x16_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_128x4x16_r2,@function +igemm_fwd_btm_nhwc_fp16_128x4x16_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 64 + + ; calculate wei offset, 4x16, 4 for k, 16 for yxc, 8 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 15, v0 + s_lshl_b32 
s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 7, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 7 ; 128 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + ; v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 
s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 8 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], 
v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + 
.fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_body + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + ; v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + 
global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 
v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_fma_body +L_igemm_fwd_btm_nhwc_fp16_128x4x16_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 64 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_128x4x16_r2 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 88 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x016.asm similarity index 98% rename from test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm rename to test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x016.asm index 136a8afd..055d0cd9 100644 --- a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_128x016.asm +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_128x016.asm @@ -21,7 +21,8 @@ .set k_stride_m, 92 .set k_magic_0, 96 .set k_magic_1, 100 -.set k_shift_pack_0, 104 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 .set s_block_ib, 2 ; bx, ho*wo .set s_ka, 0 @@ -50,27 +51,28 @@ .set s_stride_m, 33 .set s_magic_0, 34 .set s_magic_1, 35 -.set s_shift_pack_0, 36 -.set s_shift_m0, 37 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 .set s_shift_m1, s_shift_pack_0 .set s_in_stride_wi, 12 .set s_in_stride_n, 13 .set s_wei_stride_k, 14 .set s_out_stride_wo, 15 -.set s_out_stride_n, 38 -.set s_in_diff_hi, 39 -.set s_in_diff_wi, 40 -.set s_dilation_w_x, 41 -.set s_move_slice_k_ix, 42 +.set s_out_stride_n, 39 +.set s_in_diff_hi, 40 +.set s_in_diff_wi, 41 +.set s_dilation_w_x, 42 +.set s_move_slice_k_ix, 43 .set s_kitr, 1 -.set s_wei_offset, 43 +.set s_wei_offset, 44 .set s_out_stride, s_wei_offset -.set s_sld_b_stride, 44 -.set s_br, 45 +.set s_sld_b_stride, 45 +.set s_br, 46 -.set s_tmp, 46 -.set s_end, 52 +.set s_tmp, 48 +.set s_end, 54 ; magic_0: x ; magic_1: wo @@ -113,7 +115,7 @@ igemm_fwd_btm_nhwc_fp16_128x16x16_r3: s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m - s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0 + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 v_mov_b32 v[v_tid], v0 ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 4 for yx, 2 for c @@ -574,7 +576,7 @@ L_igemm_fwd_btm_nhwc_fp16_128x16x16_r3_end: .amdhsa_system_sgpr_workgroup_id_z 1 .amdhsa_system_vgpr_workitem_id 0 .amdhsa_next_free_vgpr 74 - .amdhsa_next_free_sgpr 52 + .amdhsa_next_free_sgpr 54 .amdhsa_ieee_mode 0 .amdhsa_dx10_clamp 0 .amdhsa_wavefront_size32 1 diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm new file mode 100644 index 00000000..f304a586 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x004.asm @@ -0,0 +1,652 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 
+.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 8 +.set v_ax, 9 +.set v_ay, 25 +.set v_ib, 41 +.set v_b, 42 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 74 +.set v_in_ihi, 76 +.set v_in_iwi, 78 +.set v_in_flag, 80 +.set v_out_os, 82 +.set v_out_flag, 84 +.set v_tid, 86 +.set v_end, 88 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_256x4x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_256x4x16_r1,@function +igemm_fwd_btm_nhwc_fp16_256x4x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 4x32, 4 for k, 32 for yxc, 16 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 5, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 31, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 15, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256 + s_lshl_b32 
s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ; s_mov_b32 s[s_tmp+5], 128*k_n_dword*4 ; stride for wei sst offset. 32 thread for gemm_k, each thread store 4 c, hence 32*4=128 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + ;v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + ; v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 
v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 8 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + 
v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + 
v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], 
s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + ; v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 
s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_256x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x4x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 88 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm new file mode 100644 index 00000000..0b880322 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x008.asm @@ -0,0 +1,1520 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 
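+; note (assumed from the .mdiv_u32_rem_* uses below): blockIdx.x is split by
+; magic-number division with s_magic_2/s_shift_m2 into a k block and an ib block, roughly
+;   blocks_k = k >> 3,  ib_blk = bx / blocks_k,  ik_blk = bx % blocks_k
+;   s_block_ib = ib_blk << 8  (256 output pixels per block)
+;   s_block_ik = ik_blk << 3  (8 gemm_n per block)
+; s_ib_stride (set to 128 below) advances v_ib between the two per-thread output slots.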
+.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 16 +.set v_ax, 17 +.set v_ay, 33 +.set v_ib, 49 +.set v_b, 50 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 114 +.set v_in_ihi, 116 +.set v_in_iwi, 118 +.set v_in_flag, 120 +.set v_out_os, 122 +.set v_out_flag, 124 +.set v_tid, 126 +.set v_end, 128 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_256x8x16_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_256x8x16_r2,@function +igemm_fwd_btm_nhwc_fp16_256x8x16_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 8 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 15, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 7, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs 
v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], 
s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], 
v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax +11, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ax +12, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ax + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ax +13, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ax + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ax +14, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ax + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ax +15, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + 
s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay +11, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + 
s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_body + + +L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] +L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_1: + s_waitcnt vmcnt(0) + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + s_mov_b32 s[s_move_slice_k_ix], 0 + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 +L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(8) + + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay 
+11, v_b +24 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56 + + + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + .v_clear_nc v_c, 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 
L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_fma_body +L_igemm_fwd_btm_nhwc_fp16_256x8x16_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x8x16_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 128 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel + + + + + +;---------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 49 +.set v_ib, 65 +.set v_b, 66 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 98 +.set v_in_ihi, 102 +.set v_in_iwi, 106 +.set v_in_flag, 110 +.set v_out_os, 114 +.set v_out_flag, 118 +.set v_tid, 122 +.set v_end, 124 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_256x8x8_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_256x8x8_r2,@function +igemm_fwd_btm_nhwc_fp16_256x8x8_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 64 + + ; calculate wei offset, 8x8, 8 for k, 8 for yxc, 8 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 3, 
v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 7, v0 ; yx + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 
8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 
s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 32 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 
v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end + +L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ax +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ax +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, 
v_ax + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ax +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ax +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ax +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax +0:v_ax +3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + 
ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_body + +L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], 
s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], 
v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] + + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + 
v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_fma_body +L_igemm_fwd_btm_nhwc_fp16_256x8x8_r2_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_256x8x8_r2 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 124 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x016.asm similarity index 98% rename from test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm rename to test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x016.asm index 60092b31..2725dc4a 100644 --- a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16_256x016.asm +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_256x016.asm @@ -21,7 +21,8 @@ .set k_stride_m, 92 .set k_magic_0, 96 .set k_magic_1, 100 -.set k_shift_pack_0, 104 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 .set s_block_ib, 2 ; bx, ho*wo .set s_ka, 0 @@ -50,25 +51,26 @@ .set s_stride_m, 33 .set s_magic_0, 34 .set s_magic_1, 35 -.set s_shift_pack_0, 36 -.set s_shift_m0, 37 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 .set s_shift_m1, s_shift_pack_0 .set s_in_stride_wi, 12 .set s_in_stride_n, 13 .set s_wei_stride_k, 14 .set s_out_stride_wo, 15 -.set s_out_stride_n, 38 -.set s_in_diff_hi, 39 -.set s_in_diff_wi, 40 -.set s_dilation_w_x, 41 -.set s_move_slice_k_ix, 42 +.set s_out_stride_n, 39 +.set s_in_diff_hi, 40 +.set s_in_diff_wi, 41 +.set s_dilation_w_x, 42 +.set s_move_slice_k_ix, 43 .set s_kitr, 1 -.set s_wei_offset, 43 +.set s_wei_offset, 44 .set s_out_stride, s_wei_offset -.set s_sld_b_stride, 44 -.set s_br, 45 -.set s_ib_stride, 46 +.set s_sld_b_stride, 45 +.set s_br, 46 +.set s_ib_stride, 47 .set s_tmp, 48 .set s_end, 54 @@ -114,7 +116,7 @@ igemm_fwd_btm_nhwc_fp16_256x16x16_r3: s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m - s_load_dword s[s_shift_pack_0], s[s_ka+0:s_ka+1], 0+k_shift_pack_0 + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 v_mov_b32 v[v_tid], v0 s_mov_b32 s[s_ib_stride], 128 diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm new file mode 100644 index 00000000..09afaa7b --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_384x004.asm @@ -0,0 +1,779 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, 
group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 12 +.set v_ax, 13 +.set v_ay, 37 +.set v_ib, 61 +.set v_b, 62 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 94 +.set v_in_ihi, 97 +.set v_in_iwi, 100 +.set v_in_flag, 103 +.set v_out_os, 106 +.set v_out_flag, 109 +.set v_tid, 112 +.set v_end, 114 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_384x4x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_384x4x16_r1,@function +igemm_fwd_btm_nhwc_fp16_384x4x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 4x32, 4 for k, 32 for yxc, 16 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 5, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 31, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 15, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + s_mov_b32 s[s_block_ib], 384 + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_mul_i32 s[s_block_ib], s[s_tmp+5], s[s_block_ib] ; 384 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + 
s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ; s_mov_b32 s[s_tmp+5], 128*k_n_dword*4 ; stride for wei sst offset. 32 thread for gemm_k, each thread store 4 c, hence 32*4=128 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs 
v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(6) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 12 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + 
ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] +igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay+ 0, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+ 8, 8 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_waitcnt vmcnt(6) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ax +16, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ax +17, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ax +18, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ax +19, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, 
v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ax +20, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ax +21, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ax +22, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ax +23, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] +igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + ;--- end move slice window + + .v_clear_nc v_ax+ 0, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+ 8, 8 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+16, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(6) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 
3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 
v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + ; v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c 
+ 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx2 v[v_out_os+2], v[v_c_buf+4:v_c_buf+5], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 12 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_384x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_384x4x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 114 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm new file mode 100644 index 00000000..1b2cc4b7 --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x004.asm @@ -0,0 +1,886 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set 
k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 16 +.set v_ax, 17 +.set v_ay, 49 +.set v_ib, 81 +.set v_b, 82 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 114 +.set v_in_ihi, 118 +.set v_in_iwi, 122 +.set v_in_flag, 126 +.set v_out_os, 130 +.set v_out_flag, 134 +.set v_tid, 138 +.set v_end, 140 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_512x4x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_512x4x16_r1,@function +igemm_fwd_btm_nhwc_fp16_512x4x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 4x32, 4 for k, 32 for yxc, 16 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 5, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 31, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 15, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss 
s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ; s_mov_b32 s[s_tmp+5], 128*k_n_dword*4 ; stride for wei sst offset. 32 thread for gemm_k, each thread store 4 c, hence 32*4=128 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 
1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, 
v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 
v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+24:v_ay+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+28:v_ay+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ax + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ax +16, v_b + 0 + .fma_1x4_fp16 v_c+12, v_ax +24, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ax + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ax +17, v_b + 4 + .fma_1x4_fp16 v_c+12, v_ax +25, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ax +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ax +18, v_b + 8 + .fma_1x4_fp16 v_c+12, v_ax +26, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ax +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ax +19, v_b +12 + .fma_1x4_fp16 v_c+12, v_ax +27, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ax + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ax +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ax +20, v_b +16 + .fma_1x4_fp16 v_c+12, v_ax +28, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ax + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ax +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ax +21, v_b +20 + .fma_1x4_fp16 v_c+12, v_ax +29, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ax + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ax +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ax +22, v_b +24 + .fma_1x4_fp16 v_c+12, v_ax +30, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ax + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ax +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ax +23, v_b +28 + .fma_1x4_fp16 v_c+12, v_ax +31, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], 
s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + .fma_1x4_fp16 v_c+12, v_ay +24, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + .fma_1x4_fp16 v_c+12, v_ay +25, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + .fma_1x4_fp16 v_c+12, v_ay +26, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + .fma_1x4_fp16 v_c+12, v_ay +27, v_b +12 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_waitcnt lgkmcnt(4) + 
.fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + .fma_1x4_fp16 v_c+12, v_ay +28, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + .fma_1x4_fp16 v_c+12, v_ay +29, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + .fma_1x4_fp16 v_c+12, v_ay +30, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + .fma_1x4_fp16 v_c+12, v_ay +31, v_b +28 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + v_mov_b32 v[v_ay +24], v[v_ax +24] + v_mov_b32 v[v_ay +25], v[v_ax +25] + v_mov_b32 v[v_ay +26], v[v_ax +26] + v_mov_b32 v[v_ay +27], v[v_ax +27] + v_mov_b32 v[v_ay +28], v[v_ax +28] + v_mov_b32 v[v_ay +29], v[v_ax +29] + v_mov_b32 v[v_ay +30], v[v_ax +30] + v_mov_b32 v[v_ay +31], v[v_ax +31] +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], 
v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + + .fma_1x4_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_fp16 v_c+ 4, v_ay + 8, v_b + 0 + .fma_1x4_fp16 v_c+ 8, v_ay +16, v_b + 0 + .fma_1x4_fp16 v_c+12, v_ay +24, v_b + 0 + + .fma_1x4_fp16 v_c+ 0, v_ay + 1, v_b + 4 + 
.fma_1x4_fp16 v_c+ 4, v_ay + 9, v_b + 4 + .fma_1x4_fp16 v_c+ 8, v_ay +17, v_b + 4 + .fma_1x4_fp16 v_c+12, v_ay +25, v_b + 4 + + .fma_1x4_fp16 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_fp16 v_c+ 4, v_ay +10, v_b + 8 + .fma_1x4_fp16 v_c+ 8, v_ay +18, v_b + 8 + .fma_1x4_fp16 v_c+12, v_ay +26, v_b + 8 + + .fma_1x4_fp16 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_fp16 v_c+ 4, v_ay +11, v_b +12 + .fma_1x4_fp16 v_c+ 8, v_ay +19, v_b +12 + .fma_1x4_fp16 v_c+12, v_ay +27, v_b +12 + + s_waitcnt lgkmcnt(0) + .fma_1x4_fp16 v_c+ 0, v_ay + 4, v_b +16 + .fma_1x4_fp16 v_c+ 4, v_ay +12, v_b +16 + .fma_1x4_fp16 v_c+ 8, v_ay +20, v_b +16 + .fma_1x4_fp16 v_c+12, v_ay +28, v_b +16 + + .fma_1x4_fp16 v_c+ 0, v_ay + 5, v_b +20 + .fma_1x4_fp16 v_c+ 4, v_ay +13, v_b +20 + .fma_1x4_fp16 v_c+ 8, v_ay +21, v_b +20 + .fma_1x4_fp16 v_c+12, v_ay +29, v_b +20 + + .fma_1x4_fp16 v_c+ 0, v_ay + 6, v_b +24 + .fma_1x4_fp16 v_c+ 4, v_ay +14, v_b +24 + .fma_1x4_fp16 v_c+ 8, v_ay +22, v_b +24 + .fma_1x4_fp16 v_c+12, v_ay +30, v_b +24 + + .fma_1x4_fp16 v_c+ 0, v_ay + 7, v_b +28 + .fma_1x4_fp16 v_c+ 4, v_ay +15, v_b +28 + .fma_1x4_fp16 v_c+ 8, v_ay +23, v_b +28 + .fma_1x4_fp16 v_c+12, v_ay +31, v_b +28 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+2:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx2 v[v_out_os+2], v[v_c_buf+4:v_c_buf+5], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx2 v[v_out_os+3], v[v_c_buf+6:v_c_buf+7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*5 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*6 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*7 + + .v_clear_nc v_c, 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], 
v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_512x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_512x4x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 140 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm new file mode 100644 index 00000000..ce44c39c --- /dev/null +++ b/test/inference/kernel/fp16/igemm_fwd_btm_nhwc_fp16_512x008.asm @@ -0,0 +1,1756 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 65 +.set v_ib, 97 +.set v_b, 98 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set 
v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 162 +.set v_in_ihi, 166 +.set v_in_iwi, 170 +.set v_in_flag, 174 +.set v_out_os, 178 +.set v_out_flag, 182 +.set v_tid, 186 +.set v_end, 188 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_512x8x16_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_512x8x16_r2,@function +igemm_fwd_btm_nhwc_fp16_512x8x16_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 8 for yx, 2 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 ; 9 dword per row, 4 row + v_and_b32 v[v_tmp+5], 15, v0 + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_and_b32 v[v_wei_ic], 1, v0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_tmp+5] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_and_b32 v[v_wei_ie], 7, v[v_tmp+4] ; yx + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+1 ; 8x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + 
global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], 
v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*8*4 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], 
v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 32 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+0:v_ay+3], 
v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+4:v_ay+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ay+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+16:v_ay+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+20:v_ay+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+24:v_ay+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ay+28:v_ay+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ax +16, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ax +24, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ax +17, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ax +25, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+16, v_ax +18, v_b +16 + .fma_1x8_fp16 v_c+24, v_ax +26, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax +11, v_b +24 + .fma_1x8_fp16 v_c+16, v_ax +19, v_b +24 + .fma_1x8_fp16 v_c+24, v_ax +27, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ax + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ax +12, v_b +32 + .fma_1x8_fp16 v_c+16, v_ax +20, v_b +32 + .fma_1x8_fp16 v_c+24, v_ax +28, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ax + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ax +13, v_b +40 + .fma_1x8_fp16 v_c+16, v_ax +21, v_b +40 + .fma_1x8_fp16 v_c+24, v_ax +29, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ax + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ax +14, v_b +48 + .fma_1x8_fp16 v_c+16, v_ax +22, v_b +48 + .fma_1x8_fp16 v_c+24, v_ax +30, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ax + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ax +15, v_b +56 + .fma_1x8_fp16 v_c+16, v_ax +23, v_b +56 + .fma_1x8_fp16 v_c+24, v_ax +31, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 
s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ; s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 16 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + .v_clear_nc v_ax+16, 16 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +27, v_b +24 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], 
offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_waitcnt lgkmcnt(8) + .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32 + .fma_1x8_fp16 v_c+16, v_ay +20, v_b +32 + .fma_1x8_fp16 v_c+24, v_ay +28, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40 + .fma_1x8_fp16 v_c+16, v_ay +21, v_b +40 + .fma_1x8_fp16 v_c+24, v_ay +29, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48 + .fma_1x8_fp16 v_c+16, v_ay +22, v_b +48 + .fma_1x8_fp16 v_c+24, v_ay +30, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56 + .fma_1x8_fp16 v_c+16, v_ay +23, v_b +56 + .fma_1x8_fp16 v_c+24, v_ay +31, v_b +56 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_body + +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + + v_mov_b32 v[v_ay +16], v[v_ax +16] + v_mov_b32 v[v_ay +17], v[v_ax +17] + v_mov_b32 v[v_ay +18], v[v_ax +18] + v_mov_b32 v[v_ay +19], v[v_ax +19] + v_mov_b32 v[v_ay +20], v[v_ax +20] + v_mov_b32 v[v_ay +21], v[v_ax +21] + v_mov_b32 v[v_ay +22], v[v_ax +22] + v_mov_b32 v[v_ay +23], v[v_ax +23] + v_mov_b32 v[v_ay +24], v[v_ax +24] + v_mov_b32 v[v_ay +25], v[v_ax +25] + v_mov_b32 v[v_ay +26], v[v_ax +26] + v_mov_b32 v[v_ay +27], v[v_ax +27] + v_mov_b32 v[v_ay +28], v[v_ax +28] + v_mov_b32 v[v_ay +29], v[v_ax +29] + v_mov_b32 v[v_ay +30], v[v_ax +30] + v_mov_b32 v[v_ay +31], v[v_ax +31] +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 
L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+4:v_ax+7], v[v_in_os], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+1], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+1], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+16, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+20, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + .v_clear_nc v_ax+24, 4 + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + .v_clear_nc v_ax+28, 4 + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+16:v_ax+19], v[v_in_os+2], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+20:v_ax+23], v[v_in_os+2], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + v_mul_lo_u32 
v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+24:v_ax+27], v[v_in_os+3], s[s_p_in:s_p_in+1] + global_load_dwordx4 v[v_ax+28:v_ax+31], v[v_in_os+3], s[s_p_in:s_p_in+1] offset:16 + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(8) + + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay +16, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +24, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay +17, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +25, v_b + 8 + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +18, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +26, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +19, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +27, v_b +24 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 4, v_b +32 + .fma_1x8_fp16 v_c+ 8, v_ay +12, v_b +32 + .fma_1x8_fp16 v_c+16, v_ay +20, v_b +32 + .fma_1x8_fp16 v_c+24, v_ay +28, v_b +32 + .fma_1x8_fp16 v_c+ 0, v_ay + 5, v_b +40 + .fma_1x8_fp16 v_c+ 8, v_ay +13, v_b +40 + .fma_1x8_fp16 v_c+16, v_ay +21, v_b +40 + .fma_1x8_fp16 v_c+24, v_ay +29, v_b +40 + .fma_1x8_fp16 v_c+ 0, v_ay + 6, v_b +48 + .fma_1x8_fp16 v_c+ 8, v_ay +14, v_b +48 + .fma_1x8_fp16 v_c+16, v_ay +22, v_b +48 + .fma_1x8_fp16 v_c+24, v_ay +30, v_b +48 + .fma_1x8_fp16 v_c+ 0, v_ay + 7, v_b +56 + .fma_1x8_fp16 v_c+ 8, v_ay +15, v_b +56 + .fma_1x8_fp16 v_c+16, v_ay +23, v_b +56 + .fma_1x8_fp16 v_c+24, v_ay +31, v_b +56 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], 
v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] + + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*4 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*4 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*5 + 0*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*5 + 4*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*6 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*6 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*7 + 0*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*7 + 4*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_fma_body +L_igemm_fwd_btm_nhwc_fp16_512x8x16_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 
+.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_512x8x16_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 188 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel + + +;---------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_c_buf, v_c +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 49 +.set v_ib, 65 +.set v_b, 66 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 98 +.set v_in_ihi, 102 +.set v_in_iwi, 106 +.set v_in_flag, 110 +.set v_out_os, 114 +.set v_out_flag, 118 +.set v_tid, 122 +.set v_end, 124 + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_fp16_512x8x8_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_fp16_512x8x8_r1,@function +igemm_fwd_btm_nhwc_fp16_512x8x8_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 15, v0 ; yx + s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + s_lshl_b32 s[s_block_in], s[s_block_in], 1 + v_mov_b32 v[v_ib], v0 + 
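+    ; thread -> weight-tile mapping for this kernel ("8 for k, 16 for yx" above):
+    ;   v_wei_ik = tid >> 4 walks gemm_k, v_wei_ie = tid & 15 walks the 16 y*x positions,
+    ;   v_wei_ic stays 0; the LDS store offset v_sst_b_os built below is
+    ;   wei_ie * (k_n_dword*4*4) + wei_ik * 4, in bytes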
v_mul_u32_u24 v[v_tmp+5], s[s_tmp], v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 3, v[v_wei_ic] ; 8xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ; s_lshl_b32 s[s_wei_offset], s[s_c], 4+1 ; 16x s_c, half + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ; v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 s[s_tmp+5], 64*k_n_dword*4 ; stride for wei sst offset. 
16 thread for gemm_k, each thread store 4 c, hence 16*4=64 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + ;v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + 
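+    ; in-bounds check for the fourth input position (index 3): v_in_ihi+3 / v_in_iwi+3 may be
+    ; negative after the pad subtraction, so the unsigned compares against s_hi / s_wi below
+    ; also reject those coordinates; v_in_os+3 becomes (ihi*s_wi + iwi) * s_in_stride_wi,
+    ; a byte offset (s_in_stride_wi was already shifted left by 1 for fp16)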
v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + s_lshl_b32 s[s_tmp+5], s[s_block_ik], 1 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+5] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 32 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate 
sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ax + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ax + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ax +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ax + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ax + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ax +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ax + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ax +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ax +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ax + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ax +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ax 
+15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax +0:v_ax +3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], 
v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_body + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag] + global_load_dwordx4 v[v_ax+0:v_ax+3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] 
+ v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_fp16 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_fp16 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_fp16 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_fp16 v_c+24, v_ay +12, v_b + 0 + .fma_1x8_fp16 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_fp16 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_fp16 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_fp16 v_c+24, v_ay +13, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_fp16 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_fp16 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_fp16 v_c+16, v_ay +10, v_b +16 + .fma_1x8_fp16 v_c+24, v_ay +14, v_b +16 + .fma_1x8_fp16 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_fp16 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_fp16 v_c+16, v_ay +11, v_b +24 + .fma_1x8_fp16 v_c+24, v_ay +15, v_b +24 + + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + v_cvt_f16_f32 v[v_c + 0], v[v_c + 0] + v_cvt_f16_f32 v[v_c + 1], v[v_c + 1] + v_cvt_f16_f32 v[v_c + 2], v[v_c + 2] + v_cvt_f16_f32 v[v_c + 3], v[v_c + 3] + v_cvt_f16_f32 v[v_c + 4], v[v_c + 4] + v_cvt_f16_f32 v[v_c + 5], v[v_c + 5] + v_cvt_f16_f32 v[v_c + 6], v[v_c + 6] + v_cvt_f16_f32 v[v_c + 7], v[v_c + 7] + + v_cvt_f16_f32 v[v_c + 8], v[v_c + 8] + v_cvt_f16_f32 v[v_c + 9], v[v_c + 9] + v_cvt_f16_f32 v[v_c +10], v[v_c +10] + v_cvt_f16_f32 v[v_c +11], v[v_c +11] + v_cvt_f16_f32 v[v_c +12], v[v_c +12] + v_cvt_f16_f32 v[v_c +13], v[v_c +13] + v_cvt_f16_f32 v[v_c +14], v[v_c +14] + v_cvt_f16_f32 v[v_c +15], v[v_c +15] + + + v_pack_b32_f16 v[v_c_buf+0], v[v_c+ 0], v[v_c+ 1] + v_pack_b32_f16 v[v_c_buf+1], v[v_c+ 2], v[v_c+ 3] + v_pack_b32_f16 v[v_c_buf+2], 
v[v_c+ 4], v[v_c+ 5] + v_pack_b32_f16 v[v_c_buf+3], v[v_c+ 6], v[v_c+ 7] + + v_pack_b32_f16 v[v_c_buf+4], v[v_c+ 8], v[v_c+ 9] + v_pack_b32_f16 v[v_c_buf+5], v[v_c+10], v[v_c+11] + v_pack_b32_f16 v[v_c_buf+6], v[v_c+12], v[v_c+13] + v_pack_b32_f16 v[v_c_buf+7], v[v_c+14], v[v_c+15] + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cvt_f16_f32 v[v_c +16], v[v_c +16] + v_cvt_f16_f32 v[v_c +17], v[v_c +17] + v_cvt_f16_f32 v[v_c +18], v[v_c +18] + v_cvt_f16_f32 v[v_c +19], v[v_c +19] + v_cvt_f16_f32 v[v_c +20], v[v_c +20] + v_cvt_f16_f32 v[v_c +21], v[v_c +21] + v_cvt_f16_f32 v[v_c +22], v[v_c +22] + v_cvt_f16_f32 v[v_c +23], v[v_c +23] + + v_cvt_f16_f32 v[v_c +24], v[v_c +24] + v_cvt_f16_f32 v[v_c +25], v[v_c +25] + v_cvt_f16_f32 v[v_c +26], v[v_c +26] + v_cvt_f16_f32 v[v_c +27], v[v_c +27] + v_cvt_f16_f32 v[v_c +28], v[v_c +28] + v_cvt_f16_f32 v[v_c +29], v[v_c +29] + v_cvt_f16_f32 v[v_c +30], v[v_c +30] + v_cvt_f16_f32 v[v_c +31], v[v_c +31] + + + v_pack_b32_f16 v[v_c_buf+ 8], v[v_c+16], v[v_c+17] + v_pack_b32_f16 v[v_c_buf+ 9], v[v_c+18], v[v_c+19] + v_pack_b32_f16 v[v_c_buf+10], v[v_c+20], v[v_c+21] + v_pack_b32_f16 v[v_c_buf+11], v[v_c+22], v[v_c+23] + + v_pack_b32_f16 v[v_c_buf+12], v[v_c+24], v[v_c+25] + v_pack_b32_f16 v[v_c_buf+13], v[v_c+26], v[v_c+27] + v_pack_b32_f16 v[v_c_buf+14], v[v_c+28], v[v_c+29] + v_pack_b32_f16 v[v_c_buf+15], v[v_c+30], v[v_c+31] + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_fma_body +L_igemm_fwd_btm_nhwc_fp16_512x8x8_r1_end: + s_endpgm + +; LDS: 1 
* 4 * 4 * 128 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_fp16_512x8x8_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 124 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm b/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm deleted file mode 100644 index d10f7dd2..00000000 --- a/test/inference/kernel/igemm_fwd_btm_nhwc_fp16.asm +++ /dev/null @@ -1,153 +0,0 @@ -; pay attention to register bank of v_c, v_b -.macro .fma_1x16_fp16 v_c, v_a, v_b - v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] - v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] - v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] - v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] - v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] - v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] - v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] - v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] - v_dot2c_f32_f16 v[\v_c+8 ], v[\v_a], v[\v_b+8 ] - v_dot2c_f32_f16 v[\v_c+9 ], v[\v_a], v[\v_b+9 ] - v_dot2c_f32_f16 v[\v_c+10], v[\v_a], v[\v_b+10] - v_dot2c_f32_f16 v[\v_c+11], v[\v_a], v[\v_b+11] - v_dot2c_f32_f16 v[\v_c+12], v[\v_a], v[\v_b+12] - v_dot2c_f32_f16 v[\v_c+13], v[\v_a], v[\v_b+13] - v_dot2c_f32_f16 v[\v_c+14], v[\v_a], v[\v_b+14] - v_dot2c_f32_f16 v[\v_c+15], v[\v_a], v[\v_b+15] -.endm - -.macro .fma_1x8_fp16 v_c, v_a, v_b - v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] - v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] - v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] - v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] - v_dot2c_f32_f16 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] - v_dot2c_f32_f16 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] - v_dot2c_f32_f16 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] - v_dot2c_f32_f16 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] -.endm - -.macro .fma_1x4_fp16 v_c, v_a, v_b - v_dot2c_f32_f16 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] - v_dot2c_f32_f16 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] - v_dot2c_f32_f16 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] - v_dot2c_f32_f16 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] -.endm - -.macro .mdiv_u32_ss s_quot s_numer s_magic s_shift s_tmp - s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer] - s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer] - s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift] -.endm - -.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp - .mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp - s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot] - s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp] -.endm - -.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp - v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer] - v_add_nc_u32 v[\v_tmp], v[\v_tmp], v[\v_numer] - v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp] -.endm - -.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp - .mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp - v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot] - v_sub_nc_u32 v[\v_rem], v[\v_numer], v[\v_tmp] -.endm - -.macro .v_clear_nc vid, num - _v = \vid - .rept \num - v_mov_b32 v[_v], 0 - _v = _v + 1 - .endr -.endm - -.include "igemm_fwd_btm_nhwc_fp16_128x016.asm" -.include "igemm_fwd_btm_nhwc_fp16_256x016.asm" - -.amdgpu_metadata ---- -amdhsa.version: [ 1, 0 ] -amdhsa.kernels: - - .name: 
igemm_fwd_btm_nhwc_fp16_128x16x16_r3 - .symbol: igemm_fwd_btm_nhwc_fp16_128x16x16_r3.kd - .sgpr_count: 58 - .vgpr_count: 74 - .kernarg_segment_align: 8 - .kernarg_segment_size: 112 - .group_segment_fixed_size: 13056 - .private_segment_fixed_size: 0 - .wavefront_size: 32 - .reqd_workgroup_size : [128, 1, 1] - .max_flat_workgroup_size: 128 - .args: - - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} - - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} - - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} - - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} - - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} - - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} - - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} - - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} - - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} - - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} - - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} - - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} - - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} - - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} - - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} - - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} - - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} - - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} - - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} - - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} - - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} - - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} - - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} - - { .name: shift_pack_0, .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} - - { .name: __pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} - - .name: igemm_fwd_btm_nhwc_fp16_256x16x16_r3 - .symbol: igemm_fwd_btm_nhwc_fp16_256x16x16_r3.kd - .sgpr_count: 60 - .vgpr_count: 112 - .kernarg_segment_align: 8 - .kernarg_segment_size: 112 - .group_segment_fixed_size: 13056 - .private_segment_fixed_size: 0 - .wavefront_size: 32 - .reqd_workgroup_size : [128, 1, 1] - .max_flat_workgroup_size: 128 - .args: - - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} - - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} - - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} - - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} - - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} - - { .name: n , .size: 4, 
.offset: 32, .value_kind: by_value, .value_type: i32} - - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} - - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} - - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} - - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} - - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} - - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} - - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} - - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} - - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} - - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} - - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} - - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} - - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} - - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} - - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} - - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} - - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} - - { .name: shift_pack_0, .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} - - { .name: __pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} -... -.end_amdgpu_metadata diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm new file mode 100644 index 00000000..d8549814 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8.asm @@ -0,0 +1,424 @@ +; pay attention to register bank of v_c, v_b +.macro .fma_1x16_int8x4 v_c, v_a, v_b + v_dot4c_i32_i8 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot4c_i32_i8 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot4c_i32_i8 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot4c_i32_i8 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot4c_i32_i8 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot4c_i32_i8 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot4c_i32_i8 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot4c_i32_i8 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] + v_dot4c_i32_i8 v[\v_c+8 ], v[\v_a], v[\v_b+8 ] + v_dot4c_i32_i8 v[\v_c+9 ], v[\v_a], v[\v_b+9 ] + v_dot4c_i32_i8 v[\v_c+10], v[\v_a], v[\v_b+10] + v_dot4c_i32_i8 v[\v_c+11], v[\v_a], v[\v_b+11] + v_dot4c_i32_i8 v[\v_c+12], v[\v_a], v[\v_b+12] + v_dot4c_i32_i8 v[\v_c+13], v[\v_a], v[\v_b+13] + v_dot4c_i32_i8 v[\v_c+14], v[\v_a], v[\v_b+14] + v_dot4c_i32_i8 v[\v_c+15], v[\v_a], v[\v_b+15] +.endm + +.macro .fma_1x8_int8x4 v_c, v_a, v_b + v_dot4c_i32_i8 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot4c_i32_i8 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot4c_i32_i8 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot4c_i32_i8 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] + v_dot4c_i32_i8 v[\v_c+4 ], v[\v_a], v[\v_b+4 ] + v_dot4c_i32_i8 v[\v_c+5 ], v[\v_a], v[\v_b+5 ] + v_dot4c_i32_i8 v[\v_c+6 ], v[\v_a], v[\v_b+6 ] + v_dot4c_i32_i8 v[\v_c+7 ], v[\v_a], v[\v_b+7 ] +.endm + +.macro .fma_1x4_int8x4 v_c, v_a, v_b + v_dot4c_i32_i8 v[\v_c+0 ], v[\v_a], v[\v_b+0 ] + v_dot4c_i32_i8 v[\v_c+1 ], v[\v_a], v[\v_b+1 ] + v_dot4c_i32_i8 v[\v_c+2 ], v[\v_a], v[\v_b+2 ] + v_dot4c_i32_i8 v[\v_c+3 ], v[\v_a], v[\v_b+3 ] +.endm + +.macro .mdiv_u32_ss 
s_quot s_numer s_magic s_shift s_tmp + s_mul_hi_u32 s[\s_tmp], s[\s_magic], s[\s_numer] + s_add_u32 s[\s_tmp], s[\s_tmp], s[\s_numer] + s_lshr_b32 s[\s_quot], s[\s_tmp], s[\s_shift] +.endm + +.macro .mdiv_u32_rem_ss s_rem s_quot s_numer s_magic s_shift s_denom s_tmp + .mdiv_u32_ss \s_quot,\s_numer,\s_magic,\s_shift,\s_tmp + s_mul_i32 s[\s_tmp], s[\s_denom], s[\s_quot] + s_sub_u32 s[\s_rem], s[\s_numer], s[\s_tmp] +.endm + +.macro .mdiv_u32_vs v_quot v_numer s_magic s_shift v_tmp + v_mul_hi_u32 v[\v_tmp], s[\s_magic], v[\v_numer] + v_add_nc_u32 v[\v_tmp], v[\v_tmp], v[\v_numer] + v_lshrrev_b32 v[\v_quot], s[\s_shift], v[\v_tmp] +.endm + +.macro .mdiv_u32_rem_vs v_rem v_quot v_numer s_magic s_shift s_denom v_tmp + .mdiv_u32_vs \v_quot,\v_numer,\s_magic,\s_shift,\v_tmp + v_mul_lo_u32 v[\v_tmp], s[\s_denom], v[\v_quot] + v_sub_nc_u32 v[\v_rem], v[\v_numer], v[\v_tmp] +.endm + +.macro .pack_i8x4_i32_r1 v_d, v_src, s_0xff + v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] + v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] + v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] + v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] + v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] + v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] + v_or_b32 v[\v_d], v[\v_src+ 0], v[\v_src+ 3] + v_or3_b32 v[\v_d], v[\v_d], v[\v_src+ 1], v[\v_src+ 2] +.endm + +.macro .pack_i8x4_i32_r2 v_d, v_src, s_0xff + v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] + v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] + + v_and_b32 v[\v_src+ 4], s[\s_0xff], v[\v_src+ 4] + v_lshlrev_b32 v[\v_src+ 7], 24, v[\v_src+ 7] + + v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] + v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] + + v_or_b32 v[\v_d+ 0], v[\v_src+ 0], v[\v_src+ 3] + + v_and_b32 v[\v_src+ 5], s[\s_0xff], v[\v_src+ 5] + + v_and_b32 v[\v_src+ 6], s[\s_0xff], v[\v_src+ 6] + v_or_b32 v[\v_d+ 1], v[\v_src+ 4], v[\v_src+ 7] + + v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] + v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] + v_lshlrev_b32 v[\v_src+ 5], 8, v[\v_src+ 5] + v_lshlrev_b32 v[\v_src+ 6], 16, v[\v_src+ 6] + + v_or3_b32 v[\v_d+ 0], v[\v_d+ 0], v[\v_src+ 1], v[\v_src+ 2] + v_or3_b32 v[\v_d+ 1], v[\v_d+ 1], v[\v_src+ 5], v[\v_src+ 6] +.endm + +;.macro .pack_i8x4_i32_r4 v_d, v_src, s_0xff +; v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] +; v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] +; v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] +; v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] +; v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] +; v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] +; v_or_b32 v[\v_d+ 0], v[\v_src+ 0], v[\v_src+ 3] +; v_or3_b32 v[\v_d+ 0], v[\v_d+ 0], v[\v_src+ 1], v[\v_src+ 2] +; +; v_and_b32 v[\v_src+ 4], s[\s_0xff], v[\v_src+ 4] +; v_and_b32 v[\v_src+ 5], s[\s_0xff], v[\v_src+ 5] +; v_and_b32 v[\v_src+ 6], s[\s_0xff], v[\v_src+ 6] +; v_lshlrev_b32 v[\v_src+ 7], 24, v[\v_src+ 7] +; v_lshlrev_b32 v[\v_src+ 5], 8, v[\v_src+ 5] +; v_lshlrev_b32 v[\v_src+ 6], 16, v[\v_src+ 6] +; v_or_b32 v[\v_d+ 1], v[\v_src+ 4], v[\v_src+ 7] +; v_or3_b32 v[\v_d+ 1], v[\v_d+ 1], v[\v_src+ 5], v[\v_src+ 6] +; +; v_and_b32 v[\v_src+ 8], s[\s_0xff], v[\v_src+ 8] +; v_and_b32 v[\v_src+ 9], s[\s_0xff], v[\v_src+ 9] +; v_and_b32 v[\v_src+10], s[\s_0xff], v[\v_src+10] +; v_lshlrev_b32 v[\v_src+11], 24, v[\v_src+11] +; v_lshlrev_b32 v[\v_src+ 9], 8, v[\v_src+ 9] +; v_lshlrev_b32 v[\v_src+10], 16, v[\v_src+10] +; v_or_b32 v[\v_d+ 2], v[\v_src+ 8], v[\v_src+11] +; v_or3_b32 v[\v_d+ 2], v[\v_d+ 2], v[\v_src+ 9], v[\v_src+10] +; +; v_and_b32 v[\v_src+12], s[\s_0xff], 
v[\v_src+12] +; v_and_b32 v[\v_src+13], s[\s_0xff], v[\v_src+13] +; v_and_b32 v[\v_src+14], s[\s_0xff], v[\v_src+14] +; v_lshlrev_b32 v[\v_src+15], 24, v[\v_src+15] +; v_lshlrev_b32 v[\v_src+13], 8, v[\v_src+13] +; v_lshlrev_b32 v[\v_src+14], 16, v[\v_src+14] +; v_or_b32 v[\v_d+ 3], v[\v_src+12], v[\v_src+15] +; v_or3_b32 v[\v_d+ 3], v[\v_d+ 3], v[\v_src+13], v[\v_src+14] +;.endm + +.macro .pack_i8x4_i32_r4 v_d, v_src, s_0xff + v_and_b32 v[\v_src+ 0], s[\s_0xff], v[\v_src+ 0] + v_lshlrev_b32 v[\v_src+ 3], 24, v[\v_src+ 3] + v_and_b32 v[\v_src+ 4], s[\s_0xff], v[\v_src+ 4] + v_lshlrev_b32 v[\v_src+ 7], 24, v[\v_src+ 7] + + v_and_b32 v[\v_src+ 8], s[\s_0xff], v[\v_src+ 8] + v_lshlrev_b32 v[\v_src+11], 24, v[\v_src+11] + v_and_b32 v[\v_src+12], s[\s_0xff], v[\v_src+12] + v_lshlrev_b32 v[\v_src+15], 24, v[\v_src+15] + + v_or_b32 v[\v_d+ 0], v[\v_src+ 0], v[\v_src+ 3] + v_or_b32 v[\v_d+ 1], v[\v_src+ 4], v[\v_src+ 7] + v_or_b32 v[\v_d+ 2], v[\v_src+ 8], v[\v_src+11] + + v_and_b32 v[\v_src+ 1], s[\s_0xff], v[\v_src+ 1] + v_or_b32 v[\v_d+ 3], v[\v_src+12], v[\v_src+15] + + v_and_b32 v[\v_src+ 2], s[\s_0xff], v[\v_src+ 2] + v_and_b32 v[\v_src+ 5], s[\s_0xff], v[\v_src+ 5] + v_and_b32 v[\v_src+ 6], s[\s_0xff], v[\v_src+ 6] + v_and_b32 v[\v_src+ 9], s[\s_0xff], v[\v_src+ 9] + v_and_b32 v[\v_src+10], s[\s_0xff], v[\v_src+10] + v_and_b32 v[\v_src+13], s[\s_0xff], v[\v_src+13] + v_and_b32 v[\v_src+14], s[\s_0xff], v[\v_src+14] + + v_lshlrev_b32 v[\v_src+ 1], 8, v[\v_src+ 1] + v_lshlrev_b32 v[\v_src+ 2], 16, v[\v_src+ 2] + + v_lshlrev_b32 v[\v_src+ 5], 8, v[\v_src+ 5] + v_lshlrev_b32 v[\v_src+ 6], 16, v[\v_src+ 6] + + v_lshlrev_b32 v[\v_src+ 9], 8, v[\v_src+ 9] + v_lshlrev_b32 v[\v_src+10], 16, v[\v_src+10] + + v_lshlrev_b32 v[\v_src+13], 8, v[\v_src+13] + v_lshlrev_b32 v[\v_src+14], 16, v[\v_src+14] + + v_or3_b32 v[\v_d+ 0], v[\v_d+ 0], v[\v_src+ 1], v[\v_src+ 2] + v_or3_b32 v[\v_d+ 1], v[\v_d+ 1], v[\v_src+ 5], v[\v_src+ 6] + v_or3_b32 v[\v_d+ 2], v[\v_d+ 2], v[\v_src+ 9], v[\v_src+10] + v_or3_b32 v[\v_d+ 3], v[\v_d+ 3], v[\v_src+13], v[\v_src+14] +.endm + + +.macro .v_clear_nc vid, num + _v = \vid + .rept \num + v_mov_b32 v[_v], 0 + _v = _v + 1 + .endr +.endm + +.include "igemm_fwd_btm_nhwc_int8_256x004.asm" +.include "igemm_fwd_btm_nhwc_int8_256x008.asm" +.include "igemm_fwd_btm_nhwc_int8_512x008.asm" +.include "igemm_fwd_btm_nhwc_int8_512x016.asm" +.include "igemm_fwd_btm_nhwc_int8_1024x016.asm" + +.amdgpu_metadata +--- +amdhsa.version: [ 1, 0 ] +amdhsa.kernels: + - .name: igemm_fwd_btm_nhwc_int8_256x4x16_r1 + .symbol: igemm_fwd_btm_nhwc_int8_256x4x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 108 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 1024 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [64, 1, 1] + .max_flat_workgroup_size: 64 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, 
.value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_256x8x16_r1 + .symbol: igemm_fwd_btm_nhwc_int8_256x8x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 80 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, 
.value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_512x8x16_r1 + .symbol: igemm_fwd_btm_nhwc_int8_512x8x16_r1.kd + .sgpr_count: 64 + .vgpr_count: 124 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 2048 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_512x16x8_r2 + .symbol: igemm_fwd_btm_nhwc_int8_512x16x8_r2.kd + .sgpr_count: 64 + .vgpr_count: 140 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + 
.reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_512x16x16_r2 + .symbol: igemm_fwd_btm_nhwc_int8_512x16x16_r2.kd + .sgpr_count: 64 + .vgpr_count: 188 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: 
by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, .value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} + - .name: igemm_fwd_btm_nhwc_int8_1024x16x8_r2 + .symbol: igemm_fwd_btm_nhwc_int8_1024x16x8_r2.kd + .sgpr_count: 64 + .vgpr_count: 244 + .kernarg_segment_align: 8 + .kernarg_segment_size: 112 + .group_segment_fixed_size: 4096 + .private_segment_fixed_size: 0 + .wavefront_size: 32 + .reqd_workgroup_size : [128, 1, 1] + .max_flat_workgroup_size: 128 + .args: + - { .name: p_in , .size: 8, .offset: 0, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_wei , .size: 8, .offset: 8, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: true} + - { .name: p_out , .size: 8, .offset: 16, .value_kind: global_buffer, .value_type: f32, .address_space: global, .is_const: false} + - { .name: hi , .size: 4, .offset: 24, .value_kind: by_value, .value_type: i32} + - { .name: wi , .size: 4, .offset: 28, .value_kind: by_value, .value_type: i32} + - { .name: n , .size: 4, .offset: 32, .value_kind: by_value, .value_type: i32} + - { .name: k , .size: 4, .offset: 36, .value_kind: by_value, .value_type: i32} + - { .name: c , .size: 4, .offset: 40, .value_kind: by_value, .value_type: i32} + - { .name: ho , .size: 4, .offset: 44, .value_kind: by_value, .value_type: i32} + - { .name: wo , .size: 4, .offset: 48, .value_kind: by_value, .value_type: i32} + - { .name: stride_h , .size: 4, .offset: 52, .value_kind: by_value, .value_type: i32} + - { .name: stride_w , .size: 4, .offset: 56, .value_kind: by_value, .value_type: i32} + - { .name: dilation_h, .size: 4, .offset: 60, .value_kind: by_value, .value_type: i32} + - { .name: dilation_w, .size: 4, .offset: 64, .value_kind: by_value, .value_type: i32} + - { .name: pad_h , .size: 4, .offset: 68, .value_kind: by_value, .value_type: i32} + - { .name: pad_w , .size: 4, .offset: 72, .value_kind: by_value, .value_type: i32} + - { .name: y , .size: 4, .offset: 76, .value_kind: by_value, .value_type: i32} + - { .name: x , .size: 4, .offset: 80, .value_kind: by_value, .value_type: i32} + - { .name: group , .size: 4, .offset: 84, .value_kind: by_value, .value_type: i32} + - { .name: batch_m , .size: 4, .offset: 88, 
.value_kind: by_value, .value_type: i32} + - { .name: stride_m , .size: 4, .offset: 92, .value_kind: by_value, .value_type: i32} + - { .name: magic_0 , .size: 4, .offset: 96, .value_kind: by_value, .value_type: i32} + - { .name: magic_1 , .size: 4, .offset: 100, .value_kind: by_value, .value_type: i32} + - { .name: magic_2 , .size: 4, .offset: 104, .value_kind: by_value, .value_type: i32} + - { .name: shift_pack_0, .size: 4, .offset: 108, .value_kind: by_value, .value_type: i32} +... +.end_amdgpu_metadata diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_1024x016.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_1024x016.asm new file mode 100644 index 00000000..c5bf9282 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_1024x016.asm @@ -0,0 +1,1081 @@ +;---------------------------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 16 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 128 +.set v_ax, 129 +.set v_ay, 145 +.set v_ib, 161 +.set v_b, 162 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 194 +.set v_in_ihi, 202 +.set v_in_iwi, 210 +.set v_in_flag, 218 +.set v_out_os, 226 +.set v_out_flag, 234 +.set v_tid, 242 +.set v_end, 244 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_1024x16x8_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_1024x16x8_r2,@function +igemm_fwd_btm_nhwc_int8_1024x16x8_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], 
v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 8 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 2 + v_and_b32 v[v_wei_ie], 7, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 4 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 10 ; 1024 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 4 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+0 ; 8x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx2 v[v_gld_b+0:v_gld_b+1], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx2 v[v_gld_b+2:v_gld_b+3], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 16*k_n_dword*4 ; stride for wei sst offset. 
8 thread for gemm_k, each thread store 2 c, hence 8*2=16 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + ;.v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + ;.v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + 
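+    ; the v_cmpx_le_u32 / global_load / s_mov_b64 exec, -1 sequence above predicates each input load on its per-pixel bounds flag, then restores all lanes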
v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + v_sub_nc_i32 v[v_in_iwi+5], v[v_in_iwi+5], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+4], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ax+ 8:v_ax+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+5], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ax+10:v_ax+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + v_sub_nc_i32 v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+6], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ax+12:v_ax+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+7], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 
v[v_in_os+7], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx2 v[v_ax+14:v_ax+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + + + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+4], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+5], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+6], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 v[v_out_flag+6], 0, 1 + + v_mul_lo_u32 v[v_out_os+7], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*2 + + s_waitcnt vmcnt(8) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 128 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 
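+    ; load the next 16 dwords of the weight tile (v_b+16..31) from the second k_n_dword*4 slice of LDS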
+ ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], 
v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ay+ 0:v_ay+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ay+ 2:v_ay+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ay+ 4:v_ay+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ay+ 6:v_ay+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 4 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ay+ 8:v_ay+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ay+10:v_ay+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+12, 4 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ay+12:v_ay+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx2 v[v_ay+14:v_ay+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x16_int8x4 v_c+ 16, v_ax + 2, v_b + 0 + .fma_1x16_int8x4 v_c+ 32, v_ax + 4, v_b + 0 + .fma_1x16_int8x4 v_c+ 48, v_ax + 6, v_b + 0 + .fma_1x16_int8x4 v_c+ 64, v_ax + 8, v_b + 0 + .fma_1x16_int8x4 v_c+ 80, v_ax +10, v_b + 0 + .fma_1x16_int8x4 v_c+ 96, v_ax +12, v_b + 0 + .fma_1x16_int8x4 v_c+112, v_ax +14, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 1, v_b +16 + .fma_1x16_int8x4 v_c+ 16, v_ax + 3, v_b +16 + .fma_1x16_int8x4 v_c+ 32, v_ax + 5, v_b +16 + .fma_1x16_int8x4 v_c+ 48, v_ax + 7, v_b +16 + .fma_1x16_int8x4 v_c+ 64, v_ax + 9, v_b +16 + .fma_1x16_int8x4 v_c+ 80, v_ax +11, v_b +16 + .fma_1x16_int8x4 v_c+ 96, v_ax +13, v_b +16 + .fma_1x16_int8x4 v_c+112, v_ax +15, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_iwi+4], s[s_tmp], v[v_in_iwi+4] + v_add_nc_u32 v[v_in_iwi+5], s[s_tmp], v[v_in_iwi+5] + 
v_add_nc_u32 v[v_in_iwi+6], s[s_tmp], v[v_in_iwi+6] + v_add_nc_u32 v[v_in_iwi+7], s[s_tmp], v[v_in_iwi+7] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + v_add_nc_u32 v[v_in_os+4], s[s_tmp+1], v[v_in_os+4] + v_add_nc_u32 v[v_in_os+5], s[s_tmp+1], v[v_in_os+5] + v_add_nc_u32 v[v_in_os+6], s[s_tmp+1], v[v_in_os+6] + v_add_nc_u32 v[v_in_os+7], s[s_tmp+1], v[v_in_os+7] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] + v_add_nc_i32 v[v_in_ihi+4], s[s_dilation_h], v[v_in_ihi+4] + v_add_nc_i32 v[v_in_ihi+5], s[s_dilation_h], v[v_in_ihi+5] + v_add_nc_i32 v[v_in_ihi+6], s[s_dilation_h], v[v_in_ihi+6] + v_add_nc_i32 v[v_in_ihi+7], s[s_dilation_h], v[v_in_ihi+7] +igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + ;--- end move slice window + + .v_clear_nc v_ax, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 4 + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ax+ 8:v_ax+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ax+10:v_ax+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+12, 4 + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ax+12:v_ax+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, 
v[v_in_flag+7] + global_load_dwordx2 v[v_ax+14:v_ax+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(8) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+ 16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+ 32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+ 48, v_ay + 6, v_b + 0 + .fma_1x16_int8x4 v_c+ 64, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+ 80, v_ay +10, v_b + 0 + .fma_1x16_int8x4 v_c+ 96, v_ay +12, v_b + 0 + .fma_1x16_int8x4 v_c+112, v_ay +14, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+ 16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+ 32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+ 48, v_ay + 7, v_b +16 + .fma_1x16_int8x4 v_c+ 64, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+ 80, v_ay +11, v_b +16 + .fma_1x16_int8x4 v_c+ 96, v_ay +13, v_b +16 + .fma_1x16_int8x4 v_c+112, v_ay +15, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_body + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 2 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+2, 2 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], 
v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 2 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+6, 2 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_add_nc_u32 v[v_in_flag+4], s[s_ib_stride], v[v_in_flag+3] + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + .mdiv_u32_rem_vs v_in_iwi+4,v_in_ihi+4,v_in_flag+4,s_magic_1,s_shift_m1,s_wo,v_in_os+4 + v_add_nc_u32 v[v_in_flag+5], s[s_ib_stride], v[v_in_flag+4] + v_mul_lo_u32 v[v_in_ihi+4], s[s_stride_h], v[v_in_ihi+4] + .v_clear_nc v_ax+8, 2 + v_sub_nc_i32 v[v_in_ihi+4], v[v_in_ihi+4], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+4], s[s_stride_w], v[v_in_iwi+4] + .v_clear_nc v_ax+10, 2 + v_sub_nc_i32 v[v_in_iwi+4], v[v_in_iwi+4], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+5,v_in_ihi+5,v_in_flag+5,s_magic_1,s_shift_m1,s_wo,v_in_os+5 + v_add_nc_u32 v[v_in_flag+6], s[s_ib_stride], v[v_in_flag+5] + v_mul_lo_u32 v[v_in_ihi+5], s[s_stride_h], v[v_in_ihi+5] + v_sub_nc_i32 v[v_in_ihi+5], v[v_in_ihi+5], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+5], s[s_stride_w], v[v_in_iwi+5] + v_sub_nc_i32 v[v_in_iwi+5], 
v[v_in_iwi+5], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+4], s[s_wi], v[v_in_ihi+4] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+4] + v_cndmask_b32 v[v_in_flag+4], 0, 1 + v_add_nc_u32 v[v_in_os+4], v[v_in_iwi+4], v[v_in_os+4] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+4] + v_cndmask_b32 v[v_in_flag+4], 0, v[v_in_flag+4] + v_mul_lo_u32 v[v_in_os+4], s[s_in_stride_wi], v[v_in_os+4] + + v_cmpx_le_u32 1, v[v_in_flag+4] + global_load_dwordx2 v[v_ax+ 8:v_ax+ 9], v[v_in_os+4], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+5], s[s_wi], v[v_in_ihi+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+5] + v_cndmask_b32 v[v_in_flag+5], 0, 1 + v_add_nc_u32 v[v_in_os+5], v[v_in_iwi+5], v[v_in_os+5] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+5] + v_cndmask_b32 v[v_in_flag+5], 0, v[v_in_flag+5] + v_mul_lo_u32 v[v_in_os+5], s[s_in_stride_wi], v[v_in_os+5] + + v_cmpx_le_u32 1, v[v_in_flag+5] + global_load_dwordx2 v[v_ax+10:v_ax+11], v[v_in_os+5], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + .mdiv_u32_rem_vs v_in_iwi+6,v_in_ihi+6,v_in_flag+6,s_magic_1,s_shift_m1,s_wo,v_in_os+6 + v_add_nc_u32 v[v_in_flag+7], s[s_ib_stride], v[v_in_flag+6] + v_mul_lo_u32 v[v_in_ihi+6], s[s_stride_h], v[v_in_ihi+6] + .v_clear_nc v_ax+12, 2 + v_sub_nc_i32 v[v_in_ihi+6], v[v_in_ihi+6], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+6], s[s_stride_w], v[v_in_iwi+6] + .v_clear_nc v_ax+14, 2 + v_sub_nc_i32 v[v_in_iwi+6], v[v_in_iwi+6], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+7,v_in_ihi+7,v_in_flag+7,s_magic_1,s_shift_m1,s_wo,v_in_os+7 + v_mul_lo_u32 v[v_in_ihi+7], s[s_stride_h], v[v_in_ihi+7] + v_sub_nc_i32 v[v_in_ihi+7], v[v_in_ihi+7], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+7], s[s_stride_w], v[v_in_iwi+7] + v_sub_nc_i32 v[v_in_iwi+7], v[v_in_iwi+7], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+6], s[s_wi], v[v_in_ihi+6] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+6] + v_cndmask_b32 v[v_in_flag+6], 0, 1 + v_add_nc_u32 v[v_in_os+6], v[v_in_iwi+6], v[v_in_os+6] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+6] + v_cndmask_b32 v[v_in_flag+6], 0, v[v_in_flag+6] + v_mul_lo_u32 v[v_in_os+6], s[s_in_stride_wi], v[v_in_os+6] + + v_cmpx_le_u32 1, v[v_in_flag+6] + global_load_dwordx2 v[v_ax+12:v_ax+13], v[v_in_os+6], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+7], s[s_wi], v[v_in_ihi+7] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+7] + v_cndmask_b32 v[v_in_flag+7], 0, 1 + v_add_nc_u32 v[v_in_os+7], v[v_in_iwi+7], v[v_in_os+7] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+7] + v_cndmask_b32 v[v_in_flag+7], 0, v[v_in_flag+7] + v_mul_lo_u32 v[v_in_os+7], s[s_in_stride_wi], v[v_in_os+7] + + v_cmpx_le_u32 1, v[v_in_flag+7] + global_load_dwordx2 v[v_ax+14:v_ax+15], v[v_in_os+7], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+ 16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+ 32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+ 48, v_ay + 6, v_b + 0 + .fma_1x16_int8x4 v_c+ 64, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+ 80, v_ay +10, v_b + 0 + .fma_1x16_int8x4 v_c+ 96, v_ay +12, v_b + 0 + .fma_1x16_int8x4 v_c+112, v_ay +14, v_b + 0 + + s_waitcnt lgkmcnt(0) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+ 16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+ 32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+ 48, v_ay + 7, v_b +16 + .fma_1x16_int8x4 v_c+ 64, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+ 80, v_ay +11, v_b +16 + .fma_1x16_int8x4 v_c+ 96, v_ay +13, v_b +16 + .fma_1x16_int8x4 
v_c+112, v_ay +15, v_b +16 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 8, v_c+32, s_0xff + .pack_i8x4_i32_r4 v_c_buf+12, v_c+48, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + .pack_i8x4_i32_r4 v_c_buf+16, v_c+64, s_0xff + .pack_i8x4_i32_r4 v_c_buf+20, v_c+80, s_0xff + v_cmpx_le_u32 1, v[v_out_flag+4] + global_store_dwordx4 v[v_out_os+4], v[v_c_buf+16:v_c_buf+19], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+5] + global_store_dwordx4 v[v_out_os+5], v[v_c_buf+20:v_c_buf+23], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+24, v_c+96, s_0xff + .pack_i8x4_i32_r4 v_c_buf+28, v_c+112, s_0xff + v_cmpx_le_u32 1, v[v_out_flag+6] + global_store_dwordx4 v[v_out_os+6], v[v_c_buf+24:v_c_buf+27], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+7] + global_store_dwordx4 v[v_out_os+7], v[v_c_buf+28:v_c_buf+31], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + .v_clear_nc v_c, 128 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + v_add_nc_u32 v[v_out_os+4], s[s_out_stride], v[v_out_os+4] + v_add_nc_u32 v[v_out_os+5], s[s_out_stride], v[v_out_os+5] + v_add_nc_u32 v[v_out_os+6], s[s_out_stride], v[v_out_os+6] + v_add_nc_u32 v[v_out_os+7], s[s_out_stride], v[v_out_os+7] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+4] + v_cndmask_b32 v[v_out_flag+4], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+5] + v_cndmask_b32 v[v_out_flag+5], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+6] + v_cndmask_b32 
v[v_out_flag+6], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+7] + v_cndmask_b32 v[v_out_flag+7], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_fma_body +L_igemm_fwd_btm_nhwc_int8_1024x16x8_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_1024x16x8_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 244 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x004.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x004.asm new file mode 100644 index 00000000..01566d15 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x004.asm @@ -0,0 +1,738 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 4 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 49 +.set v_ib, 65 +.set v_b, 66 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 82 +.set v_in_ihi, 86 +.set v_in_iwi, 90 +.set v_in_flag, 94 +.set v_out_os, 98 +.set v_out_flag, 102 +.set v_tid, 106 +.set v_end, 108 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_256x4x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_256x4x16_r1,@function +igemm_fwd_btm_nhwc_int8_256x4x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + 
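+    ; the remaining scalar arguments (sizes, strides, padding, and magic-division constants) are bulk-loaded below before first use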
s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 64 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 4x16, 4 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 15, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 2 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 2 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ;s_lshl_b32 s[s_wei_offset], s[s_c], 4+0 ; 16x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 
8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 
s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 
v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x4_int8x4 v_c+ 4, v_ax + 4, v_b + 0 + .fma_1x4_int8x4 v_c+ 8, v_ax + 8, v_b + 0 + .fma_1x4_int8x4 v_c+12, v_ax +12, v_b + 0 + + .fma_1x4_int8x4 v_c+ 0, v_ax + 1, v_b + 4 + .fma_1x4_int8x4 v_c+ 4, v_ax + 5, v_b + 4 + .fma_1x4_int8x4 v_c+ 8, v_ax + 9, v_b + 4 + .fma_1x4_int8x4 v_c+12, v_ax +13, v_b + 4 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + + s_waitcnt lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ax + 2, v_b + 8 + .fma_1x4_int8x4 v_c+ 4, v_ax + 6, v_b + 8 + .fma_1x4_int8x4 v_c+ 8, v_ax +10, v_b + 8 + .fma_1x4_int8x4 v_c+12, v_ax +14, v_b + 8 + + .fma_1x4_int8x4 v_c+ 0, v_ax + 3, v_b +12 + .fma_1x4_int8x4 v_c+ 4, v_ax + 7, v_b +12 + .fma_1x4_int8x4 v_c+ 8, v_ax +11, v_b +12 + .fma_1x4_int8x4 v_c+12, v_ax +15, v_b +12 + + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], 
s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_int8x4 v_c+ 4, v_ay + 4, v_b + 0 + .fma_1x4_int8x4 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x4_int8x4 v_c+12, v_ay +12, v_b + 0 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_int8x4 v_c+ 4, v_ay + 5, v_b + 4 + .fma_1x4_int8x4 v_c+ 8, v_ay + 9, v_b + 4 + .fma_1x4_int8x4 v_c+12, v_ay +13, v_b + 4 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + + s_waitcnt lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_int8x4 v_c+ 4, v_ay + 6, v_b + 8 + .fma_1x4_int8x4 v_c+ 8, v_ay +10, v_b + 8 + .fma_1x4_int8x4 v_c+12, v_ay +14, v_b + 8 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_int8x4 v_c+ 4, v_ay + 7, v_b +12 + .fma_1x4_int8x4 v_c+ 8, v_ay +11, v_b +12 + .fma_1x4_int8x4 v_c+12, v_ay +15, v_b +12 + + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 
2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + 
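+    ; the two unsigned compares around this point build v_in_flag+2 = (ihi < s_hi) && (iwi < s_wi);
+    ; negative (padding) coordinates wrap to large unsigned values and fail the compare, so the
+    ; exec-masked global_load below is skipped and the lane keeps the zeros from .v_clear_nc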
v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt vmcnt(4) lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x4_int8x4 v_c+ 4, v_ay + 4, v_b + 0 + .fma_1x4_int8x4 v_c+ 8, v_ay + 8, v_b + 0 + .fma_1x4_int8x4 v_c+12, v_ay +12, v_b + 0 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 1, v_b + 4 + .fma_1x4_int8x4 v_c+ 4, v_ay + 5, v_b + 4 + .fma_1x4_int8x4 v_c+ 8, v_ay + 9, v_b + 4 + .fma_1x4_int8x4 v_c+12, v_ay +13, v_b + 4 + + s_waitcnt lgkmcnt(2) + .fma_1x4_int8x4 v_c+ 0, v_ay + 2, v_b + 8 + .fma_1x4_int8x4 v_c+ 4, v_ay + 6, v_b + 8 + .fma_1x4_int8x4 v_c+ 8, v_ay +10, v_b + 8 + .fma_1x4_int8x4 v_c+12, v_ay +14, v_b + 8 + + .fma_1x4_int8x4 v_c+ 0, v_ay + 3, v_b +12 + .fma_1x4_int8x4 v_c+ 4, v_ay + 7, v_b +12 + .fma_1x4_int8x4 v_c+ 8, v_ay +11, v_b +12 + .fma_1x4_int8x4 v_c+12, v_ay +15, v_b +12 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dword v[v_out_os], v[v_c_buf+0], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dword v[v_out_os+1], v[v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dword v[v_out_os+2], v[v_c_buf+2], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dword v[v_out_os+3], v[v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*1 + + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*2 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*3 + + .v_clear_nc v_c, 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_end + s_branch 
L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_fma_body +L_igemm_fwd_btm_nhwc_int8_256x4x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 64 +; r1 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_256x4x16_r1 + .amdhsa_group_segment_fixed_size 1024 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 108 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x008.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x008.asm new file mode 100644 index 00000000..b34ac185 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_256x008.asm @@ -0,0 +1,585 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 16 +.set v_ax, 17 +.set v_ay, 25 +.set v_ib, 33 +.set v_b, 34 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 66 +.set v_in_ihi, 68 +.set v_in_iwi, 70 +.set v_in_flag, 72 +.set v_out_os, 74 +.set v_out_flag, 76 +.set v_tid, 78 +.set v_end, 80 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_256x8x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_256x8x16_r1,@function +igemm_fwd_btm_nhwc_int8_256x8x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 
0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 15, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 8 ; 256 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ;s_lshl_b32 s[s_wei_offset], s[s_c], 4+0 ; 16x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 
8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], 
s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + ;v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + ;v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(2) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 16 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(2) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ax + 4, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ax + 5, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 
8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ax + 6, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ax + 7, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] +igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(2) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end: + 
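+    ; reached when the most recent input tile was loaded into v_ax (prologue, or the "a buffer y"
+    ; half of the loop): copy v_ax into v_ay so the shared tail at _fma_end_1 always consumes the
+    ; final 16 gemm_k from v_ay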
s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + ;v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+ 2:v_c_buf+ 3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], 
offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_fma_body +L_igemm_fwd_btm_nhwc_int8_256x8x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_256x8x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 80 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x008.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x008.asm new file mode 100644 index 00000000..4866b14c --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x008.asm @@ -0,0 +1,757 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 8 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 
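+; 0xff byte mask, consumed by the .pack_i8x4_i32_r4 macro when the int32 accumulators are packed
+; back into int8x4 dwords before the global_store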
+.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 32 +.set v_ax, 33 +.set v_ay, 49 +.set v_ib, 65 +.set v_b, 66 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+4 +.set v_wei_ix_list, v_b+5 +.set v_wei_flag, v_b+6 +.set v_wei_os, v_b+7 +.set v_tmp, v_b+8 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 98 +.set v_in_ihi, 102 +.set v_in_iwi, 106 +.set v_in_flag, 110 +.set v_out_os, 114 +.set v_out_flag, 118 +.set v_tid, 122 +.set v_end, 124 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_512x8x16_r1 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_512x8x16_r1,@function +igemm_fwd_btm_nhwc_int8_512x8x16_r1: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 8x16, 8 for k, 16 for yxc, 16 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 4, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 15, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 3 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 3 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + ;s_lshl_b32 s[s_wei_offset], s[s_c], 4+0 ; 16x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + ;v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + ;v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + ;s_mov_b32 
s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + 
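+    ; pixel 3 uses the same addressing as the pixels above: v_in_os = (ihi * s_wi + iwi) * s_in_stride_wi,
+    ; with s_in_stride_wi = s_c * s_group; the data is int8, so this element offset is already a byte offset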
v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + ;v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 32 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 
v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ax + 4, v_b + 0 + .fma_1x8_int8x4 v_c+16, v_ax + 8, v_b + 0 + .fma_1x8_int8x4 v_c+24, v_ax +12, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ax + 5, v_b + 8 + .fma_1x8_int8x4 v_c+16, v_ax + 9, v_b + 8 + .fma_1x8_int8x4 v_c+24, v_ax +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ax + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ax + 6, v_b +16 + .fma_1x8_int8x4 v_c+16, v_ax +10, v_b +16 + .fma_1x8_int8x4 v_c+24, v_ax +14, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ax + 3, v_b +24 
+ .fma_1x8_int8x4 v_c+ 8, v_ax + 7, v_b +24 + .fma_1x8_int8x4 v_c+16, v_ax +11, v_b +24 + .fma_1x8_int8x4 v_c+24, v_ax +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_int8x4 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_int8x4 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_int8x4 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_int8x4 v_c+24, v_ay +13, v_b + 8 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + 
ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_int8x4 v_c+16, v_ay +10, v_b +16 + .fma_1x8_int8x4 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_int8x4 v_c+16, v_ay +11, v_b +24 + .fma_1x8_int8x4 v_c+24, v_ay +15, v_b +24 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_body + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + 
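+    ; v_in_flag+1 served as a temporary row index for the .mdiv above; from here on it holds the
+    ; in-bounds flag for pixel 1, (ihi < s_hi) && (iwi < s_wi), completed by the compare below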
v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x8_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x8_int8x4 v_c+ 8, v_ay + 4, v_b + 0 + .fma_1x8_int8x4 v_c+16, v_ay + 8, v_b + 0 + .fma_1x8_int8x4 v_c+24, v_ay +12, v_b + 0 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 1, v_b + 8 + .fma_1x8_int8x4 v_c+ 8, v_ay + 5, v_b + 8 + .fma_1x8_int8x4 v_c+16, v_ay + 9, v_b + 8 + .fma_1x8_int8x4 v_c+24, v_ay +13, v_b + 8 + + s_waitcnt lgkmcnt(0) + .fma_1x8_int8x4 v_c+ 0, v_ay + 2, v_b +16 + .fma_1x8_int8x4 v_c+ 8, v_ay + 6, v_b +16 + .fma_1x8_int8x4 v_c+16, v_ay +10, v_b +16 + .fma_1x8_int8x4 v_c+24, v_ay +14, v_b +16 + + .fma_1x8_int8x4 v_c+ 0, v_ay + 3, v_b +24 + .fma_1x8_int8x4 v_c+ 8, v_ay + 7, v_b +24 + .fma_1x8_int8x4 v_c+16, v_ay +11, v_b +24 + .fma_1x8_int8x4 v_c+24, v_ay +15, v_b +24 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx2 v[v_out_os], v[v_c_buf+0:v_c_buf+1], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx2 v[v_out_os+1], v[v_c_buf+ 2:v_c_buf+ 3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx2 v[v_out_os+2], v[v_c_buf+ 4:v_c_buf+ 5], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx2 
v[v_out_os+3], v[v_c_buf+ 6:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + + .v_clear_nc v_c, 32 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_fma_body +L_igemm_fwd_btm_nhwc_int8_512x8x16_r1_end: + s_endpgm + +; LDS: 1 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_512x8x16_r1 + .amdhsa_group_segment_fixed_size 2048 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 124 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x016.asm b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x016.asm new file mode 100644 index 00000000..79371df7 --- /dev/null +++ b/test/inference/kernel/int8/igemm_fwd_btm_nhwc_int8_512x016.asm @@ -0,0 +1,1545 @@ +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 16 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set s_magic_2, 
36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 64 +.set v_ax, 65 +.set v_ay, 81 +.set v_ib, 97 +.set v_b, 98 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 162 +.set v_in_ihi, 166 +.set v_in_iwi, 170 +.set v_in_flag, 174 +.set v_out_os, 178 +.set v_out_flag, 182 +.set v_tid, 186 +.set v_end, 188 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_512x16x16_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_512x16x16_r2,@function +igemm_fwd_btm_nhwc_int8_512x16x16_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 8 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 4 + v_and_b32 v[v_wei_ie], 7, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 4 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 4 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+0 ; 8x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + .mdiv_u32_rem_vs 
v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx4 v[v_gld_b+0:v_gld_b+3], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx4 v[v_gld_b+4:v_gld_b+7], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 32*k_n_dword*4 ; stride for wei sst offset. 8 thread for gemm_k, each thread store 4 c, hence 8*4=32 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], 
v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + ;v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*4 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, v[v_wei_flag+0] + ds_write2_b32 
v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+4], v[v_gld_b+5], offset0:k_n_dword*0 offset1:k_n_dword*1 + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+6], v[v_gld_b+7], offset0:k_n_dword*2 offset1:k_n_dword*3 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 64 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], 
v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ay+ 0:v_ay+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ay+ 4:v_ay+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ay+ 8:v_ay+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ay+12:v_ay+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ax + 4, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ax + 8, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ax +12, v_b + 0 + + .fma_1x16_int8x4 v_c+ 0, v_ax + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ax + 5, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ax + 9, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ax +13, v_b +16 + + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_waitcnt lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ax + 2, v_b +32 + .fma_1x16_int8x4 v_c+16, v_ax + 6, v_b +32 + .fma_1x16_int8x4 v_c+32, v_ax +10, v_b +32 + .fma_1x16_int8x4 v_c+48, v_ax +14, v_b +32 + + .fma_1x16_int8x4 v_c+ 0, v_ax + 3, v_b +48 + .fma_1x16_int8x4 v_c+16, v_ax + 7, v_b +48 + .fma_1x16_int8x4 v_c+32, v_ax +11, v_b +48 + .fma_1x16_int8x4 v_c+48, v_ax +15, v_b +48 + + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], 
v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 8 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+8, 8 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay +12, v_b + 0 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay +13, v_b +16 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_waitcnt lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ay + 2, v_b +32 + .fma_1x16_int8x4 v_c+16, v_ay + 6, v_b +32 + .fma_1x16_int8x4 v_c+32, v_ay +10, v_b +32 + .fma_1x16_int8x4 v_c+48, v_ay +14, v_b +32 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 3, v_b +48 + .fma_1x16_int8x4 v_c+16, v_ay + 7, v_b +48 + .fma_1x16_int8x4 v_c+32, v_ay +11, v_b +48 + .fma_1x16_int8x4 v_c+48, v_ay +15, v_b +48 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], 
offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 16 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_body + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + v_mov_b32 v[v_ay + 8], v[v_ax + 8] + v_mov_b32 v[v_ay + 9], v[v_ax + 9] + v_mov_b32 v[v_ay +10], v[v_ax +10] + v_mov_b32 v[v_ay +11], v[v_ax +11] + v_mov_b32 v[v_ay +12], v[v_ax +12] + v_mov_b32 v[v_ay +13], v[v_ax +13] + v_mov_b32 v[v_ay +14], v[v_ax +14] + v_mov_b32 v[v_ay +15], v[v_ax +15] + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx4 v[v_ax+ 0:v_ax+ 3], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx4 v[v_ax+ 4:v_ax+ 7], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 
v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+8, 4 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx4 v[v_ax+ 8:v_ax+11], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx4 v[v_ax+12:v_ax+15], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(8) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 8, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay +12, v_b + 0 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 9, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay +13, v_b +16 + + s_waitcnt lgkmcnt(0) + .fma_1x16_int8x4 v_c+ 0, v_ay + 2, v_b +32 + .fma_1x16_int8x4 v_c+16, v_ay + 6, v_b +32 + .fma_1x16_int8x4 v_c+32, v_ay +10, v_b +32 + .fma_1x16_int8x4 v_c+48, v_ay +14, v_b +32 + + .fma_1x16_int8x4 v_c+ 0, v_ay + 3, v_b +48 + .fma_1x16_int8x4 v_c+16, v_ay + 7, v_b +48 + .fma_1x16_int8x4 v_c+32, v_ay +11, v_b +48 + .fma_1x16_int8x4 v_c+48, v_ay +15, v_b +48 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 8, v_c+32, s_0xff + .pack_i8x4_i32_r4 v_c_buf+12, v_c+48, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 
+ ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + ds_read_b128 v[v_b+32:v_b+35], v[v_sld_b_os], offset:k_n_dword*4*2 + 0*4 + ds_read_b128 v[v_b+36:v_b+39], v[v_sld_b_os], offset:k_n_dword*4*2 + 4*4 + ds_read_b128 v[v_b+40:v_b+43], v[v_sld_b_os], offset:k_n_dword*4*2 + 8*4 + ds_read_b128 v[v_b+44:v_b+47], v[v_sld_b_os], offset:k_n_dword*4*2 +12*4 + ds_read_b128 v[v_b+48:v_b+51], v[v_sld_b_os], offset:k_n_dword*4*3 + 0*4 + ds_read_b128 v[v_b+52:v_b+55], v[v_sld_b_os], offset:k_n_dword*4*3 + 4*4 + ds_read_b128 v[v_b+56:v_b+59], v[v_sld_b_os], offset:k_n_dword*4*3 + 8*4 + ds_read_b128 v[v_b+60:v_b+63], v[v_sld_b_os], offset:k_n_dword*4*3 +12*4 + + .v_clear_nc v_c, 64 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 16 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_fma_body +L_igemm_fwd_btm_nhwc_int8_512x16x16_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_512x16x16_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 188 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel + +;---------------------------------------------------------------------------------- +.set k_p_in, 0 +.set k_p_wei, 8 +.set k_p_out, 16 +.set k_hi, 24 +.set k_wi, 28 +.set k_n, 32 +.set k_k, 36 +.set k_c, 40 +.set k_ho, 44 +.set k_wo, 48 +.set k_stride_h, 52 +.set k_stride_w, 56 +.set k_dilation_h, 60 +.set k_dilation_w, 64 +.set k_pad_h, 68 +.set k_pad_w, 72 +.set k_y, 76 +.set k_x, 80 +.set k_group, 84 +.set k_batch_m, 88 +.set k_stride_m, 92 +.set k_magic_0, 96 +.set k_magic_1, 100 +.set k_magic_2, 104 +.set k_shift_pack_0, 108 +.set k_n_dword, 16 + +.set s_ka, 0 +.set s_bx, 2 ; bx, ho*wo +.set s_block_ig, 3 ; by, group +.set s_block_in, 4 ; bz, batch +.set s_p_in, 6 +.set s_p_wei, 8 +.set s_p_out, 10 +.set s_hi, 16 +.set s_wi, 17 +.set s_n, 18 +.set s_k, 19 +.set s_c, 20 +.set s_ho, 21 +.set s_wo, 22 +.set s_stride_h, 23 +.set s_stride_w, 24 +.set s_dilation_h, 25 +.set s_dilation_w, 26 +.set s_pad_h, 27 +.set s_pad_w, 28 +.set s_y, 29 +.set s_x, 30 +.set s_group, 31 +.set s_batch_m, 32 +.set s_stride_m, 33 +.set s_magic_0, 34 +.set s_magic_1, 35 +.set 
s_magic_2, 36 +.set s_shift_pack_0, 37 +.set s_shift_m0, 38 +.set s_shift_m1, s_shift_pack_0 +.set s_shift_m2, 39 +.set s_in_stride_wi, 12 +.set s_in_stride_n, 13 +.set s_wei_stride_k, 14 +.set s_out_stride_wo, 15 +.set s_out_stride_n, 40 +.set s_in_diff_hi, 41 +.set s_in_diff_wi, 42 +.set s_dilation_w_x, 43 +.set s_move_slice_k_ix, 44 + +.set s_kitr, 1 +.set s_wei_offset, 45 +.set s_out_stride, s_wei_offset +.set s_sld_b_stride, 46 +.set s_br, 47 +.set s_ib_stride, 48 +.set s_block_ik, 49 +.set s_block_ib, 50 +.set s_0xff, 51 +.set s_tmp, 52 +.set s_end, 58 + +; magic_0: x +; magic_1: wo + +.set v_c, 0 +.set v_sld_b_os, 64 +.set v_ax, 65 +.set v_ay, 73 +.set v_ib, 81 +.set v_b, 82 +.set v_gld_b, v_b +.set v_wei_iy_list, v_b+8 +.set v_wei_ix_list, v_b+10 +.set v_wei_flag, v_b+12 +.set v_wei_os, v_b+14 +.set v_tmp, v_b+16 +.set v_wei_ik, v_ay +.set v_wei_ic, v_ay+1 +.set v_wei_ie, v_ay+2 +.set v_wei_flag_ik, v_ay+3 +.set v_sst_b_os, v_ay+4 +.set v_in_os, 114 +.set v_in_ihi, 118 +.set v_in_iwi, 122 +.set v_in_flag, 126 +.set v_out_os, 130 +.set v_out_flag, 134 +.set v_tid, 138 +.set v_end, 140 +.set v_c_buf, v_b + +; short wide igemv +.text +.globl igemm_fwd_btm_nhwc_int8_512x16x8_r2 +.p2align 8 + +.type igemm_fwd_btm_nhwc_int8_512x16x8_r2,@function +igemm_fwd_btm_nhwc_int8_512x16x8_r2: + s_load_dwordx2 s[s_p_in+0:s_p_in+1], s[s_ka+0:s_ka+1], 0+k_p_in + s_load_dwordx4 s[s_p_wei+0:s_p_wei+3], s[s_ka+0:s_ka+1], 0+k_p_wei + s_load_dwordx16 s[s_hi+0:s_hi+15], s[s_ka+0:s_ka+1], 0+k_hi + s_load_dwordx4 s[s_batch_m:s_batch_m+3], s[s_ka+0:s_ka+1], 0+k_batch_m + s_load_dwordx2 s[s_magic_2:s_magic_2+1], s[s_ka+0:s_ka+1], 0+k_magic_2 + v_mov_b32 v[v_tid], v0 + s_mov_b32 s[s_ib_stride], 128 + s_mov_b32 s[s_0xff], 0xff + + ; calculate wei offset, 16x8, 16 for k, 8 for yxc, 8 for yx, 1 for c + v_lshrrev_b32 v[v_wei_ik], 3, v0 + s_mov_b32 s[s_tmp], k_n_dword*4 * 2 + v_and_b32 v[v_wei_ie], 7, v0 ; yx + ;s_lshl_b32 s[s_block_ig], s[s_block_ig], 1 + v_mov_b32 v[v_wei_ic], 0 + ;s_lshl_b32 s[s_block_in], s[s_block_in], 1 + ;v_lshrrev_b32 v[v_tmp+4], 1, v0 + v_mov_b32 v[v_ib], v0 + v_mul_u32_u24 v[v_tmp+5], s[s_tmp] ,v[v_wei_ie] + v_lshlrev_b32 v[v_sst_b_os], 2, v[v_wei_ik] ; store, k*n*k_pack, ds_write2 if possible, n*k_pack->16dword, pad to x + v_mov_b32 v[v_sld_b_os], 0 ; load + v_lshlrev_b32 v[v_wei_ic], 4, v[v_wei_ic] ; 16xc, k_pack, 4x dword + v_add_nc_u32 v[v_sst_b_os], v[v_sst_b_os], v[v_tmp+5] ; note, do not use or due to pad + + s_waitcnt lgkmcnt(0) + s_bfe_u32 s[s_shift_m2], s[s_shift_pack_0], 0x00080010 ; offset:16, width:8 + s_lshr_b32 s[s_tmp+3], s[s_k], 4 + s_bfe_u32 s[s_shift_m0], s[s_shift_pack_0], 0x00080000 ; offset:0, width:8 + .mdiv_u32_rem_ss s_tmp+4,s_tmp+5,s_bx,s_magic_2,s_shift_m2,s_tmp+3,s_tmp + s_lshl_b32 s[s_block_ib], s[s_tmp+5], 9 ; 512 + s_lshl_b32 s[s_block_ik], s[s_tmp+4], 4 + v_add_nc_u32 v[v_ib], s[s_block_ib], v[v_ib] + s_mul_i32 s[s_tmp], s[s_x], s[s_c] + v_add_nc_u32 v[v_wei_ik], s[s_block_ik], v[v_wei_ik] + + v_mad_u32_u24 v[v_tmp+1], s[s_c], v[v_wei_ie], v[v_wei_ic] + s_mul_i32 s[s_wei_stride_k], s[s_tmp], s[s_y] + s_lshl_b32 s[s_wei_offset], s[s_c], 3+0 ; 8x s_c, int8 + s_mul_i32 s[s_tmp+5], s[s_wei_stride_k], s[s_k] + v_mad_u32_u24 v[v_wei_os], s[s_wei_stride_k], v[v_wei_ik], v[v_tmp+1] + s_mul_i32 s[s_tmp+2], s[s_block_ig], s[s_tmp+5] + v_cmp_gt_u32 s[s_k], v[v_wei_ik] + s_add_u32 s[s_p_wei], s[s_p_wei], s[s_tmp+2] + v_cndmask_b32 v[v_wei_flag_ik], 0, 1 + s_addc_u32 s[s_p_wei+1], s[s_p_wei+1], 0 + ;v_lshlrev_b32 v[v_wei_os], 1, v[v_wei_os] + + ; divide x + 
.mdiv_u32_rem_vs v_wei_ix_list+0,v_wei_iy_list+0,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_add_nc_u32 v[v_wei_os+1], s[s_wei_offset], v[v_wei_os+0] + v_add_nc_u32 v[v_wei_ie], 8, v[v_wei_ie] + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+0] + v_cndmask_b32 v[v_wei_flag+0], 0, v[v_wei_flag+0] + + .mdiv_u32_rem_vs v_wei_ix_list+1,v_wei_iy_list+1,v_wei_ie,s_magic_0,s_shift_m0,s_x,v_tmp + v_cmp_gt_u32 s[s_y], v[v_wei_iy_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag_ik] + v_cmp_gt_u32 s[s_x], v[v_wei_ix_list+1] + v_cndmask_b32 v[v_wei_flag+1], 0, v[v_wei_flag+1] + + v_cmpx_le_u32 1, v[v_wei_flag+0] + global_load_dwordx2 v[v_gld_b+0:v_gld_b+1], v[v_wei_os+0], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + global_load_dwordx2 v[v_gld_b+2:v_gld_b+3], v[v_wei_os+1], s[s_p_wei:s_p_wei+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_tmp+5], 16*k_n_dword*4 ; stride for wei sst offset. 8 thread for gemm_k, each thread store 2 c, hence 8*2=16 gemm_k + + ; calculate in offset + s_mul_i32 s[s_in_stride_wi], s[s_c], s[s_group] + s_bfe_u32 s[s_shift_m1], s[s_shift_pack_0], 0x00080008 ; offset:8, width:8 + s_mul_i32 s[s_tmp+2], s[s_wi], s[s_in_stride_wi] + s_mul_i32 s[s_tmp+0], s[s_block_ig], s[s_c] + s_mul_i32 s[s_in_stride_n], s[s_hi], s[s_tmp+2] + s_mul_i32 s[s_tmp+3], s[s_block_in], s[s_in_stride_n] + ;s_lshl_b32 s[s_in_stride_wi], s[s_in_stride_wi], 1 + s_add_u32 s[s_tmp+0], s[s_tmp+0], s[s_tmp+3] + v_add_nc_u32 v[v_sst_b_os+1], s[s_tmp+5], v[v_sst_b_os+0] + + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_tmp + s_add_u32 s[s_p_in], s[s_p_in], s[s_tmp+0] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_addc_u32 s[s_p_in+1], s[s_p_in+1], 0 + v_mul_lo_u32 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 4 + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + ;.v_clear_nc v_ax+4, 4 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi] + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_tmp] + + v_mul_lo_u32 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+1], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+5] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 4 + 
v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + ;.v_clear_nc v_ax+12, 4 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_tmp+5,s_magic_1,s_shift_m1,s_wo,v_tmp + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+2], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_tmp], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_tmp], v[v_in_iwi+3], v[v_tmp] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_tmp] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mul_i32 s[s_br], s[s_wo], s[s_ho] + + s_mul_i32 s[s_out_stride_wo], s[s_k], s[s_group] + s_mul_i32 s[s_in_diff_wi], s[s_dilation_w], s[s_in_stride_wi] + s_mov_b32 s[s_move_slice_k_ix], 0 + + s_mul_i32 s[s_out_stride_n], s[s_br], s[s_out_stride_wo] + s_mul_i32 s[s_tmp+1], s[s_block_ig], s[s_k] + s_mul_i32 s[s_tmp+4], s[s_block_in], s[s_out_stride_n] + ;s_lshl_b32 s[s_tmp+5], s[s_block_ik], 0 + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_tmp+4] + s_add_u32 s[s_tmp+1], s[s_tmp+1], s[s_block_ik] + s_add_u32 s[s_p_out], s[s_p_out], s[s_tmp+1] + s_addc_u32 s[s_p_out+1], s[s_p_out+1], 0 + + ; calculate diffs, for y, x + s_sub_i32 s[s_tmp+3], s[s_x], 1 + s_mul_i32 s[s_tmp], s[s_in_diff_wi], s[s_tmp+3] + s_mul_i32 s[s_tmp+1], s[s_in_stride_wi], s[s_wi] + s_mul_i32 s[s_tmp+1], s[s_tmp+1], s[s_dilation_h] + s_sub_i32 s[s_in_diff_hi], s[s_tmp+1], s[s_tmp] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w], s[s_tmp+3] + s_mul_i32 s[s_dilation_w_x], s[s_dilation_w_x], -1 + + + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_ib] + s_mul_i32 s[s_out_stride], s[s_stride_m], s[s_out_stride_wo] + + ;s_lshl_b32 s[s_out_stride], s[s_out_stride], 1 + ;s_lshl_b32 s[s_out_stride_n], s[s_out_stride_n], 1 + + ; output offset + v_mul_lo_u32 v[v_out_os], s[s_k], v[v_ib] + ;v_lshlrev_b32 v[v_out_os], 1, v[v_out_os] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + v_add_nc_u32 v[v_tmp+4], s[s_ib_stride], v[v_tmp+5] + + v_mul_lo_u32 v[v_out_os+1], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+1], 1, v[v_out_os+1] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_add_nc_u32 v[v_tmp+5], s[s_ib_stride], v[v_tmp+4] + + v_mul_lo_u32 v[v_out_os+2], s[s_k], v[v_tmp+4] + ;v_lshlrev_b32 v[v_out_os+2], 1, v[v_out_os+2] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + + v_mul_lo_u32 v[v_out_os+3], s[s_k], v[v_tmp+5] + ;v_lshlrev_b32 v[v_out_os+3], 1, v[v_out_os+3] + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_mov_b32 s[s_sld_b_stride], k_n_dword*4*2 + + s_waitcnt vmcnt(4) + + v_cmpx_le_u32 1, 
v[v_wei_flag+0] + ds_write2_b32 v[v_sst_b_os+0], v[v_gld_b+0], v[v_gld_b+1], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_wei_flag+1] + ds_write2_b32 v[v_sst_b_os+1], v[v_gld_b+2], v[v_gld_b+3], offset0:k_n_dword*0 offset1:k_n_dword*1 + s_mov_b64 exec, -1 + + .v_clear_nc v_c, 64 + + s_waitcnt lgkmcnt(0) + s_barrier + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + s_cmp_gt_i32 s[s_kitr], 0 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_body: + ; accumulate im + + ; a buffer x + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_1 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_1: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + ;s_waitcnt vmcnt(0) + .v_clear_nc v_ay, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ay+ 0:v_ay+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ay+ 2:v_ay+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ay+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ay+ 4:v_ay+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + 
global_load_dwordx2 v[v_ay+ 6:v_ay+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ax + 2, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ax + 4, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ax + 6, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ax + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ax + 3, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ax + 5, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ax + 7, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_1 + + ; a buffer y + ;--- start move slice window + s_add_u32 s[s_move_slice_k_ix], 1, s[s_move_slice_k_ix] + s_cmp_le_u32 s[s_x], s[s_move_slice_k_ix] + s_cselect_b32 s[s_tmp], s[s_dilation_w_x], s[s_dilation_w] + s_cselect_b32 s[s_tmp+1], s[s_in_diff_hi], s[s_in_diff_wi] + v_add_nc_u32 v[v_in_iwi+0], s[s_tmp], v[v_in_iwi+0] + v_add_nc_u32 v[v_in_iwi+1], s[s_tmp], v[v_in_iwi+1] + v_add_nc_u32 v[v_in_iwi+2], s[s_tmp], v[v_in_iwi+2] + v_add_nc_u32 v[v_in_iwi+3], s[s_tmp], v[v_in_iwi+3] + v_add_nc_u32 v[v_in_os+0], s[s_tmp+1], v[v_in_os+0] + v_add_nc_u32 v[v_in_os+1], s[s_tmp+1], v[v_in_os+1] + v_add_nc_u32 v[v_in_os+2], s[s_tmp+1], v[v_in_os+2] + v_add_nc_u32 v[v_in_os+3], s[s_tmp+1], v[v_in_os+3] + s_cbranch_scc0 igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_2 + s_mov_b32 s[s_move_slice_k_ix], 0 + v_add_nc_i32 v[v_in_ihi+0], s[s_dilation_h], v[v_in_ihi+0] + v_add_nc_i32 v[v_in_ihi+1], s[s_dilation_h], v[v_in_ihi+1] + v_add_nc_i32 v[v_in_ihi+2], s[s_dilation_h], v[v_in_ihi+2] + v_add_nc_i32 v[v_in_ihi+3], s[s_dilation_h], v[v_in_ihi+3] +igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_acc_yx_x_end_2: + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+0] + v_cndmask_b32 v[v_in_flag+0], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+0] + v_cndmask_b32 v[v_in_flag+0], 0, v[v_in_flag+0] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + ;--- end move slice window + + .v_clear_nc v_ax, 4 + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os+0], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + .v_clear_nc v_ax+4, 4 + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + 
s_mov_b64 exec, -1 + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_waitcnt vmcnt(4) lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay + 6, v_b + 0 + + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay + 7, v_b +16 + + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + + s_sub_i32 s[s_kitr], s[s_kitr], 8 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + s_cmp_gt_i32 s[s_kitr], 0 + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_body + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end: + s_waitcnt vmcnt(0) + + v_mov_b32 v[v_ay + 0], v[v_ax + 0] + v_mov_b32 v[v_ay + 1], v[v_ax + 1] + v_mov_b32 v[v_ay + 2], v[v_ax + 2] + v_mov_b32 v[v_ay + 3], v[v_ax + 3] + v_mov_b32 v[v_ay + 4], v[v_ax + 4] + v_mov_b32 v[v_ay + 5], v[v_ax + 5] + v_mov_b32 v[v_ay + 6], v[v_ax + 6] + v_mov_b32 v[v_ay + 7], v[v_ax + 7] + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_1: + s_waitcnt vmcnt(0) + + s_sub_i32 s[s_batch_m], s[s_batch_m], 1 + v_add_nc_u32 v[v_ib], s[s_stride_m], v[v_ib] + + s_cmp_gt_i32 s[s_batch_m], 0 + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_not_load_next + ; --- start move slice for batch m + ; ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + ; iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + ; we will update v_in_os below, so use this as v_tmp + .mdiv_u32_rem_vs v_in_iwi,v_in_ihi,v_ib,s_magic_1,s_shift_m1,s_wo,v_in_os + v_mul_u32_u24 v[v_in_ihi], s[s_stride_h], v[v_in_ihi] + .v_clear_nc v_ax, 2 + v_add_nc_u32 v[v_in_flag+1], s[s_ib_stride], v[v_ib] + v_sub_nc_i32 v[v_in_ihi], v[v_in_ihi], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi], s[s_stride_w], v[v_in_iwi] + .v_clear_nc v_ax+2, 2 + v_sub_nc_i32 v[v_in_iwi], v[v_in_iwi], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+1,v_in_ihi+1,v_in_flag+1,s_magic_1,s_shift_m1,s_wo,v_in_os+1 + + v_mul_u32_u24 v[v_in_os], s[s_wi], v[v_in_ihi] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi] + v_cndmask_b32 v[v_in_flag], 0, 1 + v_add_nc_u32 v[v_in_os], v[v_in_iwi], v[v_in_os] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi] + v_cndmask_b32 v[v_in_flag], 0, v[v_in_flag] + v_mul_lo_u32 v[v_in_os], s[s_in_stride_wi], v[v_in_os] + + v_mul_u32_u24 v[v_in_ihi+1], s[s_stride_h], v[v_in_ihi+1] + v_sub_nc_i32 v[v_in_ihi+1], v[v_in_ihi+1], s[s_pad_h] + v_mul_u32_u24 v[v_in_iwi+1], s[s_stride_w], v[v_in_iwi+1] + v_sub_nc_i32 v[v_in_iwi+1], v[v_in_iwi+1], s[s_pad_w] + + v_add_nc_u32 v[v_in_flag+2], s[s_ib_stride], v[v_in_flag+1] + + v_cmpx_le_u32 1, v[v_in_flag+0] + global_load_dwordx2 v[v_ax+ 0:v_ax+ 1], v[v_in_os], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_u32_u24 v[v_in_os+1], s[s_wi], v[v_in_ihi+1] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+1] 
+ v_cndmask_b32 v[v_in_flag+1], 0, 1 + v_add_nc_u32 v[v_in_os+1], v[v_in_iwi+1], v[v_in_os+1] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+1] + v_cndmask_b32 v[v_in_flag+1], 0, v[v_in_flag+1] + v_mul_lo_u32 v[v_in_os+1], s[s_in_stride_wi], v[v_in_os+1] + + v_cmpx_le_u32 1, v[v_in_flag+1] + global_load_dwordx2 v[v_ax+ 2:v_ax+ 3], v[v_in_os+1], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + .mdiv_u32_rem_vs v_in_iwi+2,v_in_ihi+2,v_in_flag+2,s_magic_1,s_shift_m1,s_wo,v_in_os+2 + v_add_nc_u32 v[v_in_flag+3], s[s_ib_stride], v[v_in_flag+2] + v_mul_lo_u32 v[v_in_ihi+2], s[s_stride_h], v[v_in_ihi+2] + .v_clear_nc v_ax+4, 2 + v_sub_nc_i32 v[v_in_ihi+2], v[v_in_ihi+2], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+2], s[s_stride_w], v[v_in_iwi+2] + .v_clear_nc v_ax+6, 2 + v_sub_nc_i32 v[v_in_iwi+2], v[v_in_iwi+2], s[s_pad_w] + + .mdiv_u32_rem_vs v_in_iwi+3,v_in_ihi+3,v_in_flag+3,s_magic_1,s_shift_m1,s_wo,v_in_os+3 + v_mul_lo_u32 v[v_in_ihi+3], s[s_stride_h], v[v_in_ihi+3] + v_sub_nc_i32 v[v_in_ihi+3], v[v_in_ihi+3], s[s_pad_h] + v_mul_lo_u32 v[v_in_iwi+3], s[s_stride_w], v[v_in_iwi+3] + v_sub_nc_i32 v[v_in_iwi+3], v[v_in_iwi+3], s[s_pad_w] + + v_mul_lo_u32 v[v_in_os+2], s[s_wi], v[v_in_ihi+2] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+2] + v_cndmask_b32 v[v_in_flag+2], 0, 1 + v_add_nc_u32 v[v_in_os+2], v[v_in_iwi+2], v[v_in_os+2] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+2] + v_cndmask_b32 v[v_in_flag+2], 0, v[v_in_flag+2] + v_mul_lo_u32 v[v_in_os+2], s[s_in_stride_wi], v[v_in_os+2] + + v_cmpx_le_u32 1, v[v_in_flag+2] + global_load_dwordx2 v[v_ax+ 4:v_ax+ 5], v[v_in_os+2], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + v_mul_lo_u32 v[v_in_os+3], s[s_wi], v[v_in_ihi+3] + v_cmp_gt_u32 s[s_hi], v[v_in_ihi+3] + v_cndmask_b32 v[v_in_flag+3], 0, 1 + v_add_nc_u32 v[v_in_os+3], v[v_in_iwi+3], v[v_in_os+3] + v_cmp_gt_u32 s[s_wi], v[v_in_iwi+3] + v_cndmask_b32 v[v_in_flag+3], 0, v[v_in_flag+3] + v_mul_lo_u32 v[v_in_os+3], s[s_in_stride_wi], v[v_in_os+3] + + v_cmpx_le_u32 1, v[v_in_flag+3] + global_load_dwordx2 v[v_ax+ 6:v_ax+ 7], v[v_in_os+3], s[s_p_in:s_p_in+1] + s_mov_b64 exec, -1 + + s_mov_b32 s[s_move_slice_k_ix], 0 + +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end_not_load_next: + ; --- end move slice for batch m + + s_waitcnt lgkmcnt(4) + .fma_1x16_int8x4 v_c+ 0, v_ay + 0, v_b + 0 + .fma_1x16_int8x4 v_c+16, v_ay + 2, v_b + 0 + .fma_1x16_int8x4 v_c+32, v_ay + 4, v_b + 0 + .fma_1x16_int8x4 v_c+48, v_ay + 6, v_b + 0 + + s_waitcnt lgkmcnt(0) + .fma_1x16_int8x4 v_c+ 0, v_ay + 1, v_b +16 + .fma_1x16_int8x4 v_c+16, v_ay + 3, v_b +16 + .fma_1x16_int8x4 v_c+32, v_ay + 5, v_b +16 + .fma_1x16_int8x4 v_c+48, v_ay + 7, v_b +16 + + v_mov_b32 v[v_sld_b_os], 0 ; reset to start + + .pack_i8x4_i32_r4 v_c_buf+ 0, v_c+ 0, s_0xff + .pack_i8x4_i32_r4 v_c_buf+ 4, v_c+16, s_0xff + v_cmpx_le_u32 1, v[v_out_flag] + global_store_dwordx4 v[v_out_os], v[v_c_buf+0:v_c_buf+3], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+1] + global_store_dwordx4 v[v_out_os+1], v[v_c_buf+ 4:v_c_buf+ 7], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + .pack_i8x4_i32_r4 v_c_buf+ 8, v_c+32, s_0xff + .pack_i8x4_i32_r4 v_c_buf+12, v_c+48, s_0xff + + v_cmpx_le_u32 1, v[v_out_flag+2] + global_store_dwordx4 v[v_out_os+2], v[v_c_buf+ 8:v_c_buf+11], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + v_cmpx_le_u32 1, v[v_out_flag+3] + global_store_dwordx4 v[v_out_os+3], v[v_c_buf+12:v_c_buf+15], s[s_p_out:s_p_out+1] + s_mov_b64 exec, -1 + + s_cmp_le_i32 s[s_batch_m], 0 + + s_cbranch_scc1 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_end + ds_read_b128 v[v_b+ 0:v_b+ 3], v[v_sld_b_os], 
offset:k_n_dword*4*0 + 0*4 + ds_read_b128 v[v_b+ 4:v_b+ 7], v[v_sld_b_os], offset:k_n_dword*4*0 + 4*4 + ds_read_b128 v[v_b+ 8:v_b+11], v[v_sld_b_os], offset:k_n_dword*4*0 + 8*4 + ds_read_b128 v[v_b+12:v_b+15], v[v_sld_b_os], offset:k_n_dword*4*0 +12*4 + ds_read_b128 v[v_b+16:v_b+19], v[v_sld_b_os], offset:k_n_dword*4*1 + 0*4 + ds_read_b128 v[v_b+20:v_b+23], v[v_sld_b_os], offset:k_n_dword*4*1 + 4*4 + ds_read_b128 v[v_b+24:v_b+27], v[v_sld_b_os], offset:k_n_dword*4*1 + 8*4 + ds_read_b128 v[v_b+28:v_b+31], v[v_sld_b_os], offset:k_n_dword*4*1 +12*4 + + .v_clear_nc v_c, 64 + v_add_nc_u32 v[v_sld_b_os], s[s_sld_b_stride], v[v_sld_b_os] ; accumulate sld_b_os + + v_add_nc_u32 v[v_out_os], s[s_out_stride], v[v_out_os] + s_sub_i32 s[s_kitr], s[s_wei_stride_k], 8 + v_add_nc_u32 v[v_out_os+1], s[s_out_stride], v[v_out_os+1] + v_add_nc_u32 v[v_out_os+2], s[s_out_stride], v[v_out_os+2] + v_add_nc_u32 v[v_out_os+3], s[s_out_stride], v[v_out_os+3] + + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os] + v_cndmask_b32 v[v_out_flag], 0, 1 + s_cmp_gt_i32 s[s_kitr], 0 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+1] + v_cndmask_b32 v[v_out_flag+1], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+2] + v_cndmask_b32 v[v_out_flag+2], 0, 1 + v_cmp_gt_u32 s[s_out_stride_n], v[v_out_os+3] + v_cndmask_b32 v[v_out_flag+3], 0, 1 + + s_cbranch_scc0 L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_end + s_branch L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_fma_body +L_igemm_fwd_btm_nhwc_int8_512x16x8_r2_end: + s_endpgm + +; LDS: 2 * 4 * 4 * 128 +; r2 4dword 4 threads +.rodata +.p2align 6 +.amdhsa_kernel igemm_fwd_btm_nhwc_int8_512x16x8_r2 + .amdhsa_group_segment_fixed_size 4096 + .amdhsa_user_sgpr_kernarg_segment_ptr 1 + .amdhsa_system_sgpr_workgroup_id_x 1 + .amdhsa_system_sgpr_workgroup_id_y 1 + .amdhsa_system_sgpr_workgroup_id_z 1 + .amdhsa_system_vgpr_workitem_id 0 + .amdhsa_next_free_vgpr 140 + .amdhsa_next_free_sgpr 58 + .amdhsa_ieee_mode 0 + .amdhsa_dx10_clamp 0 + .amdhsa_wavefront_size32 1 + .amdhsa_workgroup_processor_mode 0 +.end_amdhsa_kernel diff --git a/test/inference/test_inference.cpp b/test/inference/test_inference.cpp index a6bf902e..17a5b067 100644 --- a/test/inference/test_inference.cpp +++ b/test/inference/test_inference.cpp @@ -17,12 +17,9 @@ #include #include "args.h" -#define USE_HALF_HPP -#ifdef USE_HALF_HPP #include "half.hpp" using float16 = half_float::half; -#endif std::string parse_base_arg(int argc, char* argv[]) { @@ -34,7 +31,7 @@ std::string parse_base_arg(int argc, char* argv[]) std::string arg = argv[1]; - if(arg != "conv" && arg != "convfp16" && arg != "--version") + if(arg != "conv" && arg != "convfp16" && arg != "convint8" && arg != "--version") { printf("Invalid Base Input Argument\n"); exit(0); @@ -89,10 +86,25 @@ typedef struct { typedef enum { driverHalf = 0, /*!< 16-bit floating point (Fully supported) */ driverFloat = 1, /*!< 32-bit floating point (Fully supported) */ + driverInt8 = 3, driverBFloat16 = 5, /*!< 16-bit binary floating point (8-bit exponent, 7-bit fraction) (Partially supported) */ } driverDataType_t; +static inline size_t get_data_byte(driverDataType_t dtype) +{ + if(dtype == driverHalf) + return 2; + if(dtype == driverFloat) + return 4; + if(dtype == driverInt8) + return 1; + if(dtype == driverBFloat16) + return 2; + assert(0); + return 0; +} + static inline int env_get_int(const char *var_name, int default_int) { char *v = getenv(var_name); int r = default_int; @@ -206,15 +218,24 @@ void gen_rand_vector(Dst_T *vec, size_t vec_size, Src_T fmin, Src_T fmax, Src_T th.join(); } 
-static inline bool valid_float(float p) +template +bool valid_float(T p) { return !(std::isnan(p) || std::isinf(p)); } + +template<> +bool valid_float(int8_t p) +{ + // there is no meaning to valid integer number + return true; +} + #ifndef ABS #define ABS(b) ((b) > 0 ? (b) : -1 * (b)) #endif template -static inline bool valid_vector(const float *ref, const T *pred, size_t n, +bool valid_vector(const float *ref, const T *pred, size_t n, double nrms = 1e-6) { double s0 = 0.0; double s1 = 0.0; @@ -225,7 +246,7 @@ static inline bool valid_vector(const float *ref, const T *pred, size_t n, for (size_t i = 0; i < n; ++i) { double ri = (double)ref[i]; double pi = (double)pred[i]; - if(!(valid_float(ref[i]) && valid_float(pred[i]))){ + if(!(valid_float(ref[i]) && valid_float(pred[i]))){ printf(" invalid float at %4zu, ref:%f, pred:%f\n", i, ri, pi); return false; } @@ -236,12 +257,12 @@ static inline bool valid_vector(const float *ref, const T *pred, size_t n, s1 += rr; if(igemm_per_pixel_check){ double delta = ABS(ABS(ri - pi) / ri); - printf("[%zu] ref:%lf, pred:%lf(0x%08x) [%s]\n", i, ri, pi, ((uint32_t *)pred)[i], delta > 3e-5? "N":"Y"); + printf("[%zu] ref:%lf, pred:%lf(0x%08x) [%s]\n", i, ri, pi, *(uint32_t*)(&pred[i]), delta > 3e-5? "N":"Y"); if (delta > 3e-5) { if(igemm_per_pixel_check_print){ if (pp_err < 100) printf("diff at %zu, ref:%lf, pred:%lf(0x%08x), d:%lf\n", i, ri, - pi, ((uint32_t *)pred)[i], delta); + pi, *(uint32_t*)(&pred[i]), delta); } pp_err++; } @@ -255,6 +276,36 @@ static inline bool valid_vector(const float *ref, const T *pred, size_t n, ; } +template<> +bool valid_vector(const float *ref, const int8_t *pred, size_t n, + double nrms) { + // int8 valid, we prefer a per pixel match + int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0); + int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1); + size_t pp_err = 0; + + for (size_t i = 0; i < n; ++i) { + if(!(valid_float(ref[i]) ) ){ + printf(" invalid float at %4zu, ref:%f\n", i, ref[i]); + return false; + } + int8_t pi = pred[i]; + int32_t ri = static_cast(ref[i]); + int8_t ri_clamp; + memcpy(&ri_clamp, &ri, 1); + + if(igemm_per_pixel_check){ + printf("[%zu] ref:%d(%d), pred:%d(0x%08x) [%s]\n", i, ri, ri_clamp, pi, + *(uint32_t*)(&pred[i]), pi != ri_clamp ? 
"N":"Y"); + } + + if(pi != ri_clamp){ + pp_err++; + } + } + return pp_err == 0; +} + static inline void dump_output_dword(const float *out, size_t n) { for (size_t i = 0; i < n; ++i) { @@ -334,24 +385,39 @@ int main(int argc, char **argv){ int repeat = env_get_int("REPEAT", REPEAT); int sclk_mhz = env_get_int("SCLK_MHZ", SCLK_MHZ); int dump_out = env_get_int("DUMP_OUT", 0); - char *hsaco = env_get_str("HSACO", HSACO); + char *gpu_naive_conv_hsaco = env_get_str("GPU_NAIVE_CONV_HSACO", GPU_NAIVE_CONV_HSACO); gpu_naive_conv_init(gpu_naive_conv_hsaco); std::string base_arg = parse_base_arg(argc, argv); + std::string default_hsaco = "igemm_fwd_btm_nhwc_"; driverDataType_t driver_data_type; - if(base_arg == "conv") + int fp_factor = 1; + if(base_arg == "conv"){ driver_data_type = driverFloat; - else if(base_arg == "convfp16") + default_hsaco += "fp32.hsaco"; + } + else if(base_arg == "convfp16"){ driver_data_type = driverHalf; + default_hsaco += "fp16.hsaco"; + fp_factor = 2; + } else if(base_arg == "convbf16") { driver_data_type = driverBFloat16; exit(0); } + else if(base_arg == "convint8") { + driver_data_type = driverInt8; + default_hsaco += "int8.hsaco"; + fp_factor = 4; + } else exit(0); + size_t data_byte = get_data_byte(driver_data_type); + char *hsaco = env_get_str("HSACO", const_cast(default_hsaco.c_str())); + hipModule_t module; HIP_CALL(hipModuleLoad(&module, hsaco)); @@ -394,20 +460,18 @@ int main(int argc, char **argv){ HIP_CALL(hipMalloc(&device_weight, static_cast(k) * c * y * x * sizeof(float))); HIP_CALL(hipMalloc(&device_output, static_cast(n) * k * ho * wo * sizeof(float))); -#ifdef USE_HALF_HPP - // fp16 type - float16 *host_input_f16 = (float16 *)malloc(n * c * hi * wi * sizeof(float16)); - float16 *host_weight_f16 = (float16 *)malloc(k * c * y * x * sizeof(float16)); - float16 *host_output_f16 = (float16 *)malloc(n * k * ho * wo * sizeof(float16)); - float16 *device_input_f16; - float16 *device_weight_f16; - float16 *device_output_f16; + void *host_input_dtype = malloc(n * c * hi * wi * data_byte); + void *host_weight_dtype = malloc(k * c * y * x * data_byte); + void *host_output_dtype = malloc(n * k * ho * wo * data_byte); - HIP_CALL(hipMalloc(&device_input_f16, n * c * hi * wi * sizeof(float16))); - HIP_CALL(hipMalloc(&device_weight_f16, k * c * y * x * sizeof(float16))); - HIP_CALL(hipMalloc(&device_output_f16, n * k * ho * wo * sizeof(float16))); -#endif + void *device_input_dtype; + void *device_weight_dtype; + void *device_output_dtype; + + HIP_CALL(hipMalloc(&device_input_dtype, n * c * hi * wi * data_byte)); + HIP_CALL(hipMalloc(&device_weight_dtype, k * c * y * x * data_byte)); + HIP_CALL(hipMalloc(&device_output_dtype, n * k * ho * wo * data_byte)); int need_verify = conv_args.get_int("verify"); @@ -438,13 +502,13 @@ int main(int argc, char **argv){ #endif } - double theo_gflops = theoritical_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd * 2/*fp16, 2x speed*/); + double theo_gflops = theoritical_gflops(((double)sclk_mhz) / 1000.0, num_cu, num_simd * fp_factor); double nrms = get_nrms(forw, driver_data_type); printf("num_cu:%d, gcn_arch:%d, theo_gflops:%f\n", num_cu, gcn_arch, theo_gflops); if (need_fwd){ - float *device_output_to_host = NULL; + void *device_output_to_host = NULL; if (need_verify) { // gen rand //gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 0.0, 1.0); @@ -454,13 +518,16 @@ int main(int argc, char **argv){ //gen_rand_vector(host_input, static_cast(n) * c * hi * wi, 1, 1); //gen_rand_vector(host_weight, static_cast(k) * c * y * x, 
1, 1); -#ifdef USE_HALF_HPP if(driver_data_type == driverHalf){ // move to different data type - tensor_copy(host_input_f16, host_input, static_cast(n) * c * hi * wi); - tensor_copy(host_weight_f16, host_weight, static_cast(k) * c * y * x); + tensor_copy(static_cast(host_input_dtype), host_input, static_cast(n) * c * hi * wi); + tensor_copy(static_cast(host_weight_dtype), host_weight, static_cast(k) * c * y * x); + } + else if(driver_data_type == driverInt8){ + // move to different data type + tensor_copy(static_cast(host_input_dtype), host_input, static_cast(n) * c * hi * wi); + tensor_copy(static_cast(host_weight_dtype), host_weight, static_cast(k) * c * y * x); } -#endif HIP_CALL(hipMemcpy(device_input, host_input, static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); @@ -476,44 +543,53 @@ int main(int argc, char **argv){ static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyDeviceToHost)); - if(driver_data_type == driverHalf){ -#ifdef USE_HALF_HPP - device_output_to_host = (float *)malloc((static_cast(n) * k * ho * wo * sizeof(float16) + 3) / 4 * 4); -#endif + if(driver_data_type == driverHalf || driver_data_type == driverInt8){ + device_output_to_host = malloc((static_cast(n) * k * ho * wo * data_byte + 3) / 4 * 4); } else{ - device_output_to_host = (float *)malloc(static_cast(n) * k * ho * wo * sizeof(float)); + device_output_to_host = malloc(static_cast(n) * k * ho * wo * sizeof(float)); } } + if(driver_data_type == driverFloat){ HIP_CALL(hipMemcpy(device_input, host_input, static_cast(n) * c * hi * wi * sizeof(float), hipMemcpyHostToDevice)); HIP_CALL(hipMemcpy(device_weight, host_weight, static_cast(k) * c * y * x * sizeof(float), hipMemcpyHostToDevice)); + }else{ + HIP_CALL(hipMemcpy(device_input_dtype, host_input_dtype, + static_cast(n) * c * hi * wi * data_byte, hipMemcpyHostToDevice)); + HIP_CALL(hipMemcpy(device_weight_dtype, host_weight_dtype, + static_cast(k) * c * y * x * data_byte, hipMemcpyHostToDevice)); } -#ifdef USE_HALF_HPP - else if(driver_data_type == driverHalf){ - HIP_CALL(hipMemcpy(device_input_f16, host_input_f16, - static_cast(n) * c * hi * wi * sizeof(float16), hipMemcpyHostToDevice)); - HIP_CALL(hipMemcpy(device_weight_f16, host_weight_f16, - static_cast(k) * c * y * x * sizeof(float16), hipMemcpyHostToDevice)); - } -#endif + igemm_fwd_btm_t conv_fwd_driver; + int valid_index = 0; for (int i = 0; i < sizeof(igemm_fwd_btm_kernel_list)/sizeof(igemm_fwd_btm_kernel_list[0]); i++) { igemm_fwd_btm_kernel_info_t *kinfo = &igemm_fwd_btm_kernel_list[i]; + if(driver_data_type == driverHalf){ + if(kinfo->data_type != "fp16") + continue; + } + else if(driver_data_type == driverInt8){ + if(kinfo->data_type != "int8") + continue; + } - - printf("[fwd:%2d] %s, ", i, conv_fwd_driver.get_kernel_name(kinfo).c_str()); + printf("[fwd:%2d] %s, ", valid_index, conv_fwd_driver.get_kernel_name(kinfo).c_str()); fflush(stdout); result_t result; -#ifdef USE_HALF_HPP - result = conv_fwd_driver.run(&conv_args, module, kinfo, device_input_f16, - device_weight_f16, device_output_f16, warmup, repeat, driver_data_type); -#endif + result = conv_fwd_driver.run(&conv_args, module, kinfo, device_input_dtype, + device_weight_dtype, device_output_dtype, warmup, repeat, driver_data_type); + valid_index++; + + if (result.return_code != 0){ + printf("not applicatble\n"); + continue; + } double gflops = measured_conv_gflops( result.duration_ms, n, c, hi, wi, k, y, x, stride_h, stride_w, @@ -526,21 +602,22 @@ int main(int argc, char **argv){ HIP_CALL(hipMemcpy(device_output_to_host, 
device_output, static_cast(n) * k * ho * wo * sizeof(float), hipMemcpyDeviceToHost)); - is_valid = valid_vector(host_output, device_output_to_host, + is_valid = valid_vector(host_output, static_cast(device_output_to_host), static_cast(n) * k * ho * wo, nrms); } -#ifdef USE_HALF_HPP - else if(driver_data_type == driverHalf) { - HIP_CALL(hipMemcpy(device_output_to_host, device_output_f16, - static_cast(n) * k * ho * wo * sizeof(float16), + else if(driver_data_type == driverHalf || driver_data_type == driverInt8) { + HIP_CALL(hipMemcpy(device_output_to_host, device_output_dtype, + static_cast(n) * k * ho * wo * data_byte, hipMemcpyDeviceToHost)); if(dump_out) - dump_output_dword(device_output_to_host, static_cast(n) * k * ho * wo / 2); - float16 *device_output_to_host_fp16 = (float16 *)device_output_to_host; - is_valid = valid_vector(host_output, device_output_to_host_fp16, + dump_output_dword(static_cast(device_output_to_host), static_cast(n) * k * ho * wo / fp_factor); + if(driver_data_type == driverHalf) + is_valid = valid_vector(host_output, static_cast(device_output_to_host), + static_cast(n) * k * ho * wo, nrms); + else if(driver_data_type == driverInt8) + is_valid = valid_vector(host_output, static_cast(device_output_to_host), static_cast(n) * k * ho * wo, nrms); } -#endif printf(", valid:%s", is_valid ? "y" : "n"); } printf("\n"); From ae7cef3733c860f601e58a12045cdcd87c09627d Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 23 Mar 2021 15:54:38 +0800 Subject: [PATCH 40/40] ignore print in wrw driver --- driver/igemm_wrw_gtc_driver.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/driver/igemm_wrw_gtc_driver.h b/driver/igemm_wrw_gtc_driver.h index 7bbd6bac..22010de1 100644 --- a/driver/igemm_wrw_gtc_driver.h +++ b/driver/igemm_wrw_gtc_driver.h @@ -568,7 +568,7 @@ class igemm_wrw_gtc_t : public igemm_driver_base_t { if (!tunable_is_valid(arg, tunable)) { result_t result; result.return_code = -1; - std::cout << "not valid tunable config." << std::endl; + // std::cout << "not valid tunable config." << std::endl; return result; }
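
A note on the int8 verification path added above: the kernel packs each 32-bit accumulator with an 0xff mask (.pack_i8x4_i32_r4 ... s_0xff), i.e. it keeps the low byte without saturating, so the new valid_vector specialization for int8_t truncates the fp32 reference the same way before doing a per-pixel exact compare (an RMS check is not meaningful for integer outputs, which either match bit-for-bit or do not). The following is a minimal, self-contained sketch of that scheme; the helper name int8_per_pixel_match and the sample values are made up for illustration, and little-endian byte order is assumed, as in the driver's memcpy of the first byte.

    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Compare an fp32 reference against int8 kernel output using low-byte
    // truncation, mirroring the 0xff pack in the kernel.
    static bool int8_per_pixel_match(const float *ref, const int8_t *pred, size_t n)
    {
        size_t err = 0;
        for (size_t i = 0; i < n; ++i) {
            int32_t ri = static_cast<int32_t>(ref[i]); // truncate toward zero, as the driver does
            int8_t  ri_trunc;
            std::memcpy(&ri_trunc, &ri, 1);            // keep only the low byte (little-endian)
            if (pred[i] != ri_trunc)
                ++err;
        }
        return err == 0;
    }

    int main()
    {
        const float  ref[4]  = {1.0f, 130.0f, -3.0f, 255.9f};
        const int8_t pred[4] = {1, -126, -3, -1};      // 130 -> 0x82 -> -126, 255 -> 0xff -> -1
        std::printf("match: %s\n", int8_per_pixel_match(ref, pred, 4) ? "y" : "n");
        return 0;
    }

Because the truncation wraps rather than clamps, a reference value of 130 is expected to come back as -126; this is why the comparison goes through the low byte instead of a saturating cast.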