
vulkan: add TurboQuant KV cache support and optimized turbo mat-vec paths#140

Open
Fenix46 wants to merge 5 commits into TheTom:feature/turboquant-kv-cache from Fenix46:turbo/vulkan

Conversation

Fenix46 commented May 10, 2026

Summary

This PR adds Vulkan backend support for TurboQuant KV cache formats and wires the missing optimized execution paths for TurboQuant types.

The main goal is to make TurboQuant KV cache usable on Vulkan for long-context inference while avoiding unnecessary fallback paths during decode and flash-attention execution.

Included changes:

  • add Vulkan support for TurboQuant KV cache formats:
    • TURBO2_0
    • TURBO3_0
    • TURBO4_0
    • TQ3_1S
    • TQ4_1S
  • add / wire Vulkan dequant shaders for TurboQuant formats
  • add flash-attention support for TURBO2_0, TURBO3_0, and TURBO4_0
  • fix copy_to_quant / copy_from_quant handling for TurboQuant types
  • fix the Vulkan SET_ROWS workgroup configuration for TURBO2_0 and TURBO4_0
  • optimize TurboQuant dequantize4() in flash_attn_base.glsl
  • register dedicated mul_mat_vec / mul_mat_vec_id Vulkan pipelines for TurboQuant types

Motivation

Before this PR, the Vulkan TurboQuant path was incomplete for KV-cache usage.

In particular:

  • some TurboQuant KV formats did not have the required Vulkan shader support;
  • TURBO2_0 and TURBO4_0 could fall through to the wrong SET_ROWS workgroup size, producing incomplete WHT output and corrupting KV cache entries;
  • the decode path could miss the dedicated quantized mul_mat_vec pipelines and fall back to a slower dequantize-then-matmul path;
  • flash attention was doing more per-element buffer loads than necessary in the TurboQuant dequant path.

This PR makes the Vulkan TurboQuant KV-cache path more complete, fixes correctness issues, and reduces avoidable overhead in the attention/decode hot paths.

Details

TurboQuant KV cache support

Adds Vulkan shader support and generated shader declarations for the TurboQuant KV-cache path, including dequantization, copy-to-quant, copy-from-quant, flash-attention integration, and required shader registration.

Supported internal formats in this PR:

  • TURBO2_0
  • TURBO3_0
  • TURBO4_0
  • TQ3_1S
  • TQ4_1S

User-facing CLI cache type names:

  • turbo2
  • turbo3
  • turbo4
  • tq3_1s
  • tq4_1s

Correctness fix for SET_ROWS

The WHT butterfly path in copy_to_quant.comp used local_size_x = 128, but the condition was only enabled for DATA_A_TURBO3_0.

As a result, TURBO2_0 and TURBO4_0 could fall back to the generic 32-thread path, producing output for only 32 of the 128 WHT lanes and corrupting KV-cache entries.

This PR extends the condition to cover all three TurboQuant types:

  • DATA_A_TURBO2_0
  • DATA_A_TURBO3_0
  • DATA_A_TURBO4_0
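To see why the 32-thread fallback corrupts the output, consider the butterfly structure of the transform. The following is an illustrative C++ sketch (not the shader itself) of an in-place 128-point Walsh-Hadamard transform: every stage pairs lanes `i` and `i + stride`, so all 128 lanes must participate; if only the first 32 run, the later stages never mix the upper lanes and the result is incomplete.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t N = 128;

// In-place fast Walsh-Hadamard transform over 128 elements. Each stage
// performs a butterfly between lanes i and i + stride; a workgroup that
// runs only 32 of the 128 lanes would leave the remaining butterflies
// unexecuted, which is the corruption this PR's gating fix addresses.
void wht128(std::array<float, N>& v) {
    for (std::size_t stride = 1; stride < N; stride *= 2) {
        for (std::size_t i = 0; i < N; ++i) {
            if ((i & stride) == 0) {
                const float a = v[i];
                const float b = v[i + stride];
                v[i]          = a + b;  // sum lane
                v[i + stride] = a - b;  // difference lane
            }
        }
    }
}
```

The WHT is self-inverse up to a factor of N, so applying it twice and dividing by 128 recovers the input, which makes it easy to sanity-check the full-width version.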

Flash-attention dequant optimization

The TurboQuant dequantize4() path in flash_attn_base.glsl now hoists shared loads out of the per-element loop.

Since iqs is always a multiple of 4, the four elements handled by dequantize4() share the same packed qs byte, and for TURBO3_0 also the same signs byte.

This reduces repeated buffer loads in the Q*K inner loop and lowers cache pressure during flash attention.
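The hoisting idea can be sketched in C++ for a hypothetical 2-bit packed format (the struct layout and names here are illustrative, not the real TurboQuant block layout): because `iqs` is a multiple of 4, all four elements decoded by one `dequantize4()` call come from the same packed `qs` byte, so that byte and the block scale can be loaded once per call instead of once per element.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative 2-bit block: one float scale plus 32 two-bit codes,
// packed 4 codes per byte. Not the actual TurboQuant layout.
struct Block2Bit {
    float   scale;
    uint8_t qs[8];
};

// After the optimization: one qs load and one scale load serve all
// four elements, instead of re-reading them inside the loop.
void dequantize4(const Block2Bit& b, int iqs, float out[4]) {
    const uint8_t q = b.qs[iqs / 4];  // hoisted: shared by the 4 elements
    const float   d = b.scale;        // hoisted: per-block scale
    for (int j = 0; j < 4; ++j) {
        // Unpack the j-th 2-bit code from the shared byte.
        out[j] = d * (float)((q >> (2 * j)) & 0x3);
    }
}
```

A 3-bit variant with a separate `signs` byte would hoist that byte the same way, which matches the TURBO3_0 case described above.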

Dedicated Vulkan mat-vec pipelines

This PR registers dedicated Vulkan mul_mat_vec and mul_mat_vec_id pipelines for:

  • TURBO2_0
  • TURBO3_0
  • TURBO4_0

Without this, the decode path could miss the optimized quantized kernels and fall back to a slower dequantize-then-matmul route.

Testing

Tested locally with Vulkan backend using TurboQuant KV cache enabled on AMD RX570 8GB.

It works well in my testing; I haven't noticed any token-generation issues on any of the Turbo variants. However, since the card lacks native FP16, I can't fully validate the TURBO4_0 tests.

What I can add is that TURBO2_0 is faster, even though everything is computed in FP32.

Suggested build test:

cmake -B build -DGGML_VULKAN=ON
cmake --build build --config Release -j

Fenix46 and others added 4 commits May 10, 2026 09:34
- Add dequant shaders for turbo2_0, turbo4_0, tq3_1s with WHT/RHT
- Add mul_mat_vec shader for tq3_1s
- Add flash attention support for turbo2_0, turbo3_0, turbo4_0
- Fix copy_to_quant/copy_from_quant for TurboQuant types
- Fix dequant_funcs_cm2.glsl typo (grid -> g2)
- Fix vulkan-shaders-gen: use vulkan1.3 target for _cm1/_int8/q8_1
- Add turbo2_0/turbo4_0 to FA scalar and cm1 shader generation
- Add pre-built ggml-vulkan-shaders.hpp with all new shader externs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The local_size_x = 128 WHT butterfly path in copy_to_quant.comp was
gated on DATA_A_TURBO3_0 only. TURBO2_0 and TURBO4_0 fell through to
the generic 32-thread path, producing an incomplete 32-of-128 WHT and
corrupting the KV cache entries for those types.

Fix: extend the condition to cover all three turbo types.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Hoist the norm/qs/signs loads out of the per-element loop in
dequantize4 for all three turbo types. Since iqs is always a multiple
of 4, all four elements within a dequantize4 call share the same qs
byte (and the same signs byte for turbo3_0). This reduces the number
of buffer loads from 4x2-4x3 per call down to 2-3, lowering cache
pressure in the Q*K inner loop.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Wire up pipeline_dequant_mul_mat_vec_f32_f32, _f16_f32, and _id_f32
for GGML_TYPE_TURBO2_0/3_0/4_0 and add them to the switch in
ggml_vk_get_dequantize_mul_mat_vec() and
ggml_vk_get_dequantize_mul_mat_vec_id(). Without this the decode path
fell back to a slower dequantize-then-matmul route instead of the
dedicated quantized kernel.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
nikp123 commented May 11, 2026

I've tested this on an RX 580 8GB and it seems to work, with a 2-3x performance gain on text ingestion. Thanks.
