From faa3a6b544c1148a99df607e8ddb17e69e69dabe Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Dec 2025 18:47:46 -0800 Subject: [PATCH] [ET-VK] QConv: Avoid dynamic indexing of temporary arrays ## Context `conv2d_q8ta_q8csw_q8to_conv2d` and `conv2d_dw_q8ta_q8csw_q8to_conv2d` shaders currently have EXTREMELY high latency on Arm GPUs (Mali/Immortalis architecture). Conversely, the `conv2d_q8ta_q8csw_q8to_linear_tiled` shaders, which are used for pointwise convolutions and 3x3 non-grouped convolutions have reasonable latency. Using `malioc` (Mali Offline Compiler) it was found that the slow shaders reported usage of the "stack", whereas the fast shader did not. According to the docs of the Mali Offline Compiler: ### Stack use Stack is a form of thread local storage that is used by compiler-generated memory allocations and register spills. The stack size metric in the report shows the size of the stack memory for a single shader thread. Later compilers generate additional sub-metrics that show the split between regions used for spilling and compiler-generated allocations. You can reduce the size of your stack in the following ways: * Avoid coding patterns that require the compiler to allocate on the stack, such as dynamically indexing into temporary arrays. * Reduce register pressure to avoid stack spilling. What the two slow shaders were doing that the fast shader was not, was loading the input convolution window into a local array, then performing the calculation once the entire input window was loaded. This caused two issues: 1. A large array would be needed to store the input window, which requires a lot of registers 2. Performing the convolution calculation with the input window array requires dynamic indexing Presumably, this was causing a lot of memory to be allocated via the stack, causing the performance regressions. 
## Changes Rewrite the `conv2d_q8ta_q8csw_q8to_conv2d` shader to not pre-load input values into a local array, and update the convolution calculation to not use dynamic indexing. Differential Revision: [D88702990](https://our.internmc.facebook.com/intern/diff/D88702990/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/common.glslh | 2 + .../graph/ops/glsl/conv2d_q8_utils.glslh | 60 +++++++++++++++++++ .../ops/glsl/conv2d_q8ta_q8csw_q8to.glsl | 53 ++++++++++------ 3 files changed, 96 insertions(+), 19 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh index 9ade64910f2..d6f0888d1fc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -29,6 +29,8 @@ #define align_up_4(x) ((x + 3) & -4) #define align_up_8(x) ((x + 7) & -8) +#define align_down_4(x) ((x) & -4) + #define mod_2(x) ((x) & 1) #define mod_4(x) ((x) & 3) #define mod_8(x) ((x) & 7) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh index 279f4f17f13..22e440f4948 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh @@ -44,6 +44,24 @@ bool in_bounds( return true; } +ivec4 load_input_block( + const int in_x4, + const int in_y, + const int ic4, + const Conv2dBlockExtents in_block_extents, + const ivec4 input_zps) { + if (!in_bounds(in_x4, in_y, ic4, in_block_extents)) { + return input_zps; + } +#ifdef PACKED_INT8_INPUT_BUFFER + const int buffer_idx = + in_y * in_block_extents.data_xz + in_x4 * in_block_extents.data.z + ic4; + return t_packed_int8_input[buffer_idx]; +#else + return texelFetch(t_packed_int8_input, ivec3(in_x4, in_y, ic4), 0); +#endif +} + Int8InputWindow1D load_input_window( const int w_start, const int w_end, @@ -100,6 +118,34 @@ ivec4 load_weight_block( #endif } +void 
conv1d_accumulate( + inout Int32Accum accum, + const ivec4 in_block, + const ivec4 weight_block, + const int kx, + const int out_x_start, + const int in_x_start) { + [[unroll]] for (int out_x = 0; out_x < 4; ++out_x) { + int in_x_offset = (out_x_start + out_x) * conv2d_params.stride.x + - conv2d_params.padding.x + + (kx * conv2d_params.dilation.x); + in_x_offset -= in_x_start; + + const bool in_bounds = in_x_offset >= 0 && in_x_offset < 4; + + [[unroll]] for (int oc = 0; oc < 4; ++oc) { + int updated = accum.data[out_x][0][oc]; + if (in_bounds) { + updated = dotPacked4x8AccSatEXT( + in_block[in_x_offset], + weight_block[oc], + updated); + } + accum.data[out_x][0][oc] = updated; + } + } +} + void perform_conv1d( inout Int32Accum accum, const Int8InputWindow1D input_window, @@ -146,6 +192,20 @@ void printWeightBlock(const ivec4 weight_block) { } } +void printInputBlock(const ivec4 input_block) { + debugPrintfEXT("InputBlock contents: \\n"); + for (int i = 0; i < 4; ++i) { + ivec4 unpacked = unpack_int8x4(input_block[i]); + debugPrintfEXT( + " [%d]: (%d, %d, %d, %d) \\n", + i, + unpacked.x, + unpacked.y, + unpacked.z, + unpacked.w); + } +} + #endif // DEBUG_MODE #endif // CONV2D_Q8_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl index 5839b13aeaa..ac05763129d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl @@ -82,12 +82,7 @@ void main() { return; } - const int out_w = mul_4(out_block_idx.data.x); - const int w_start = - (out_w * conv2d_params.stride.x) - conv2d_params.padding.x; - const int w_end = ((out_w + 3) * conv2d_params.stride.x) - - conv2d_params.padding.x + - (conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x; + const int out_x_start = mul_4(out_block_idx.data.x); Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes); 
@@ -99,24 +94,29 @@ void main() { const int IC4_per_group = div_up_4(conv2d_params.in_channels_per_group); - const int n = mul_4(out_block_idx.data.z); - const int group_idx = n / conv2d_params.out_channels_per_group; + const int out_z = mul_4(out_block_idx.data.z); + const int group_idx = out_z / conv2d_params.out_channels_per_group; const int group_ic4_offset = group_idx * IC4_per_group; for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) { - const int h = out_block_idx.data.y * conv2d_params.stride.y - + const int in_y = out_block_idx.data.y * conv2d_params.stride.y - conv2d_params.padding.y + ky * conv2d_params.dilation.y; - for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { - Int8InputWindow1D int8_input_window = load_input_window( - w_start, - w_end, - h, - group_ic4_offset + ic4, - in_block_extents, - input_zps); + for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) { + int in_x_load_start = + (out_x_start * conv2d_params.stride.x) + - conv2d_params.padding.x + + (kx * conv2d_params.dilation.x); - for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) { + int in_x_load_end = + ((out_x_start + 3) * conv2d_params.stride.x) + - conv2d_params.padding.x + + (kx * conv2d_params.dilation.x); + + in_x_load_start = align_down_4(in_x_load_start); + in_x_load_end = align_down_4(in_x_load_end); + + for (int ic4 = 0; ic4 < IC4_per_group; ic4++) { const ivec4 weight_block = load_weight_block( ic4, kx, @@ -127,7 +127,22 @@ void main() { conv2d_params.kernel_size.y, out_block_extents.data.z); - perform_conv1d(out_accum, int8_input_window, weight_block, kx); + for (int in_x = in_x_load_start; in_x <= in_x_load_end; in_x+=4) { + const ivec4 in_block = load_input_block( + div_4(in_x), + in_y, + group_ic4_offset + ic4, + in_block_extents, + input_zps); + + conv1d_accumulate( + out_accum, + in_block, + weight_block, + kx, + out_x_start, + in_x); + } } } }