Skip to content

Commit c309694

Browse files
author
ssjia
committed
[ET-VK] QConv: Avoid dynamic indexing of temporary arrays
## Context `conv2d_q8ta_q8csw_q8to_conv2d` and `conv2d_dw_q8ta_q8csw_q8to_conv2d` shaders currently have EXTREMELY slow latency on Arm GPUs (Mali/Immortalis architecture). Conversely, the `conv2d_q8ta_q8csw_q8to_linear_tiled` shaders, which are used for pointwise convolutions and 3x3 non-grouped convolutions have reasonable latency. Using `malioc` (Mali Offline Compiler) it was found that the slow shaders reported usage of the "stack", whereas the fast shader did not. According to the docs of the Mali Offline Compiler: ### Stack use Stack is a form of thread local storage that is used by compiler-generated memory allocations and register spills. The stack size metric in the report shows the size of the stack memory for a single shader thread. Later compilers generate additional sub-metrics that show the split between regions used for spilling and compiler-generated allocations. You can reduce the size of your stack in the following ways: * Avoid coding patterns that require the compiler to allocate on the stack, such as dynamically indexing into temporary arrays. * Reduce register pressure to avoid stack spilling. What the two slow shaders were doing that the fast shader was not, was loading the input convolution window into a local array, then performing the calculation once the entire inpur window was loaded. This caused two issues: 1. A large array would be needed to store the input window, which requires a lot of registers 2. Performing the convolution calculation with the input window array requires dynamic indexing Presumably, this was causing a lot of memory to be allocated via the stack, causing the performance regressions. ## Changes Rewrite the `conv2d_q8ta_q8csw_q8to_conv2d` shader to not pre-load input values into a local array, and update the convolution calculation to not use dynamic indexing. Differential Revision: [D88702990](https://our.internmc.facebook.com/intern/diff/D88702990/) ghstack-source-id: 327992212 Pull Request resolved: #16142
1 parent 2ee4013 commit c309694

File tree

3 files changed

+96
-19
lines changed

3 files changed

+96
-19
lines changed

backends/vulkan/runtime/graph/ops/glsl/common.glslh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#define align_up_4(x) ((x + 3) & -4)
3030
#define align_up_8(x) ((x + 7) & -8)
3131

32+
#define align_down_4(x) ((x) & -4)
33+
3234
#define mod_2(x) ((x) & 1)
3335
#define mod_4(x) ((x) & 3)
3436
#define mod_8(x) ((x) & 7)

backends/vulkan/runtime/graph/ops/glsl/conv2d_q8_utils.glslh

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,24 @@ bool in_bounds(
4444
return true;
4545
}
4646

47+
ivec4 load_input_block(
48+
const int in_x4,
49+
const int in_y,
50+
const int ic4,
51+
const Conv2dBlockExtents in_block_extents,
52+
const ivec4 input_zps) {
53+
if (!in_bounds(in_x4, in_y, ic4, in_block_extents)) {
54+
return input_zps;
55+
}
56+
#ifdef PACKED_INT8_INPUT_BUFFER
57+
const int buffer_idx =
58+
in_y * in_block_extents.data_xz + in_x4 * in_block_extents.data.z + ic4;
59+
return t_packed_int8_input[buffer_idx];
60+
#else
61+
return texelFetch(t_packed_int8_input, ivec3(in_x4, in_y, ic4), 0);
62+
#endif
63+
}
64+
4765
Int8InputWindow1D load_input_window(
4866
const int w_start,
4967
const int w_end,
@@ -100,6 +118,34 @@ ivec4 load_weight_block(
100118
#endif
101119
}
102120

121+
void conv1d_accumulate(
122+
inout Int32Accum accum,
123+
const ivec4 in_block,
124+
const ivec4 weight_block,
125+
const int kx,
126+
const int out_x_start,
127+
const int in_x_start) {
128+
[[unroll]] for (int out_x = 0; out_x < 4; ++out_x) {
129+
int in_x_offset = (out_x_start + out_x) * conv2d_params.stride.x
130+
- conv2d_params.padding.x
131+
+ (kx * conv2d_params.dilation.x);
132+
in_x_offset -= in_x_start;
133+
134+
const bool in_bounds = in_x_offset >= 0 && in_x_offset < 4;
135+
136+
[[unroll]] for (int oc = 0; oc < 4; ++oc) {
137+
int updated = accum.data[out_x][0][oc];
138+
if (in_bounds) {
139+
updated = dotPacked4x8AccSatEXT(
140+
in_block[in_x_offset],
141+
weight_block[oc],
142+
updated);
143+
}
144+
accum.data[out_x][0][oc] = updated;
145+
}
146+
}
147+
}
148+
103149
void perform_conv1d(
104150
inout Int32Accum accum,
105151
const Int8InputWindow1D input_window,
@@ -146,6 +192,20 @@ void printWeightBlock(const ivec4 weight_block) {
146192
}
147193
}
148194

195+
void printInputBlock(const ivec4 input_block) {
196+
debugPrintfEXT("InputBlock contents: \\n");
197+
for (int i = 0; i < 4; ++i) {
198+
ivec4 unpacked = unpack_int8x4(input_block[i]);
199+
debugPrintfEXT(
200+
" [%d]: (%d, %d, %d, %d) \\n",
201+
i,
202+
unpacked.x,
203+
unpacked.y,
204+
unpacked.z,
205+
unpacked.w);
206+
}
207+
}
208+
149209
#endif // DEBUG_MODE
150210

151211
#endif // CONV2D_Q8_UTILS_GLSLH

backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.glsl

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,7 @@ void main() {
8282
return;
8383
}
8484

85-
const int out_w = mul_4(out_block_idx.data.x);
86-
const int w_start =
87-
(out_w * conv2d_params.stride.x) - conv2d_params.padding.x;
88-
const int w_end = ((out_w + 3) * conv2d_params.stride.x) -
89-
conv2d_params.padding.x +
90-
(conv2d_params.kernel_size.x - 1) * conv2d_params.dilation.x;
85+
const int out_x_start = mul_4(out_block_idx.data.x);
9186

9287
Conv2dBlockExtents in_block_extents = make_block_extents(input_sizes);
9388

@@ -99,24 +94,29 @@ void main() {
9994

10095
const int IC4_per_group = div_up_4(conv2d_params.in_channels_per_group);
10196

102-
const int n = mul_4(out_block_idx.data.z);
103-
const int group_idx = n / conv2d_params.out_channels_per_group;
97+
const int out_z = mul_4(out_block_idx.data.z);
98+
const int group_idx = out_z / conv2d_params.out_channels_per_group;
10499
const int group_ic4_offset = group_idx * IC4_per_group;
105100

106101
for (int ky = 0; ky < conv2d_params.kernel_size.y; ky++) {
107-
const int h = out_block_idx.data.y * conv2d_params.stride.y -
102+
const int in_y = out_block_idx.data.y * conv2d_params.stride.y -
108103
conv2d_params.padding.y + ky * conv2d_params.dilation.y;
109104

110-
for (int ic4 = 0; ic4 < IC4_per_group; ic4++) {
111-
Int8InputWindow1D int8_input_window = load_input_window(
112-
w_start,
113-
w_end,
114-
h,
115-
group_ic4_offset + ic4,
116-
in_block_extents,
117-
input_zps);
105+
for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
106+
int in_x_load_start =
107+
(out_x_start * conv2d_params.stride.x)
108+
- conv2d_params.padding.x
109+
+ (kx * conv2d_params.dilation.x);
118110

119-
for (int kx = 0; kx < conv2d_params.kernel_size.x; kx++) {
111+
int in_x_load_end =
112+
((out_x_start + 3) * conv2d_params.stride.x)
113+
- conv2d_params.padding.x
114+
+ (kx * conv2d_params.dilation.x);
115+
116+
in_x_load_start = align_down_4(in_x_load_start);
117+
in_x_load_end = align_down_4(in_x_load_end);
118+
119+
for (int ic4 = 0; ic4 < IC4_per_group; ic4++) {
120120
const ivec4 weight_block = load_weight_block(
121121
ic4,
122122
kx,
@@ -127,7 +127,22 @@ void main() {
127127
conv2d_params.kernel_size.y,
128128
out_block_extents.data.z);
129129

130-
perform_conv1d(out_accum, int8_input_window, weight_block, kx);
130+
for (int in_x = in_x_load_start; in_x <= in_x_load_end; in_x+=4) {
131+
const ivec4 in_block = load_input_block(
132+
div_4(in_x),
133+
in_y,
134+
group_ic4_offset + ic4,
135+
in_block_extents,
136+
input_zps);
137+
138+
conv1d_accumulate(
139+
out_accum,
140+
in_block,
141+
weight_block,
142+
kx,
143+
out_x_start,
144+
in_x);
145+
}
131146
}
132147
}
133148
}

0 commit comments

Comments
 (0)