
[Issue]: [Critical/Correctness] Unsafe Assumptions on PageSize, LDS Layout, and Async Synchronization in FMHA Pipeline #3712

@red1239109-cmd

Description

Problem Description

Hello,

While analyzing block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp and the associated pipeline logic to optimize for custom tile configurations, I identified several architectural risks that can lead to silent data corruption (SDC) or race conditions.

The current implementation seems to rely on implicit assumptions (e.g., power-of-2 page sizes, specific LDS layouts) that are not enforced by static_assert, making the kernel fragile to configuration changes.

Below is a detailed analysis of the 6 critical issues found:

  1. Implicit Power-of-2 Assumption on kPageBlockSize
    In kv_offset_array_transform, the code calculates page indices using bitwise operations:

C++
// Logic assumes kPageBlockSize is 2^N
page_id = global_token_idx >> kLog2PageSize;
token_idx_in_page = global_token_idx & ((1<<kLog2PageSize)-1);
Problem: This optimization is mathematically valid only if kPageBlockSize is a power of 2. If a user configures a page size like 48, 80, or 96, this logic produces incorrect page lookups without any compilation error.

Impact: Silent OOB reads/writes.

Suggestion: Add static_assert((kPageBlockSize & (kPageBlockSize-1)) == 0, "Page size must be power of 2") or implement a fallback path using integer division/modulo.

  2. Insufficient Logic for kVTileCrossesPages
    The check kVTileCrossesPages = (page_size > 1) && (page_size % kN0 != 0) is used to determine if a tile crosses page boundaries.

Problem: This heuristic relies solely on kN0 and does not account for complex Vectorized Layouts (where K decomposition might be {K2, K0, K1}). The actual memory access pattern depends on the thread distribution in Y-space.

Impact: False negatives in this check will lead to incorrect offset calculations for threads that actually cross page boundaries.

  3. Unsafe readfirstlane Optimization
    When !kVTileCrossesPages is assumed, the code uses __builtin_amdgcn_readfirstlane to broadcast the page index from lane 0 to the entire wave.

Problem: This optimization is only safe if all lanes in the wave are guaranteed to be in the same page. If the assumption in Point 2 is incorrect (due to layout/distribution mismatch), this broadcast forces all threads to access the page of lane 0.

Impact: Wave-wide data corruption where threads read/write to the wrong page.

Suggestion: Restrict this optimization to strictly proven cases or add debug asserts verifying all_lanes_same_page_id.

  4. Potential LDS Pointer Aliasing between K and V
    The code initializes LDS pointers from the same base address:

C++
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto v_lds = make_tensor_view(reinterpret_cast<VDataType*>(smem_ptr), ...);
Problem: There is no explicit offset arithmetic visible in this scope to separate K and V regions. Unless Policy or Descriptor guarantees an internal offset for V, these pointers will alias.

Impact: V stores overwriting K data (or vice versa), causing race conditions that are extremely difficult to debug in an async pipeline.

Suggestion: Explicitly separate pointers (e.g., smem_ptr + k_offset) or document/verify that MakeVLdsBlockDescriptor handles this safely.

  5. Fragile Async Barrier Conditions
    The synchronization logic depends on specific buffer sequence patterns:

C++
if constexpr(k1_loops >= 2 && LdsSeq.at(0) == LdsSeq.at(k0_loops + k1_loops - 2))
__builtin_amdgcn_s_barrier();
Problem: This condition is tightly coupled to the implementation details of LdsSeq (Policy). Changes to the policy or different loop counts could inadvertently remove the necessary barrier, creating a hazard.

Suggestion: Generalize the hazard detection logic based on buffer index reuse, rather than hardcoded sequence comparisons.

  6. Potential Off-by-One Error in update_v_offsets
    The pipeline mixes load_tile, move_tile_window, and update_v_offsets across main loops and branches (e.g., k1_loops).

Problem: In complex scenarios (especially with outer iterations over K2), splitting the offset updates across branches creates a high risk of using a stale offset for the next load (an off-by-one tile error).

Suggestion: Standardize the update timing (e.g., always update immediately after load) to ensure consistency across all branches.

Summary: These issues represent fundamental correctness flaws rather than simple bugs. I strongly recommend a comprehensive review of the memory layout assumptions and synchronization logic in the async pipeline.


Appendix: Suggested Fix Implementation (Reference)

To expedite the resolution, here are the recommended patch logics for the identified critical issues. Please integrate these or equivalent fixes into the codebase.

1. Fix Page Size Assumption (Bitwise vs Arithmetic)

The current bitwise logic is unsafe for non-power-of-2 page sizes.
Recommendation: Add a compile-time guard or a runtime fallback.

// inside kv_offset_array_transform start
static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0,
    "kPageBlockSize must be power-of-two for current shift/mask logic. "
    "Please implement arithmetic fallback for generic page sizes.");

// OR implement fallback:
const index_t page_id =
    (kPageBlockSize & (kPageBlockSize - 1)) == 0
        ? (global_token_idx >> kLog2PageSize)
        : (global_token_idx / kPageBlockSize);

const index_t token_idx_in_page =
    (kPageBlockSize & (kPageBlockSize - 1)) == 0
        ? (global_token_idx & kInPageOffsetMask)
        : (global_token_idx % kPageBlockSize);

2. Fix LDS Pointer Aliasing (Split K and V)

Currently, k_lds_ptr and v_lds alias the same base pointer. They must be explicitly separated.
Recommendation: Offset the V pointer by the size of K.

// In Operator():
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);

// Explicitly calculate K LDS size (Implementation required in Policy)
constexpr index_t kLdsSizeK = Policy::template GetSmemSizeK<Problem>();

// V starts after K
auto v_lds_ptr = reinterpret_cast<VDataType*>(
    reinterpret_cast<char*>(smem_ptr) + kLdsSizeK);

auto v_lds = make_tensor_view<address_space_enum::lds>(
    v_lds_ptr,
    Policy::template MakeVLdsBlockDescriptor<Problem>());

3. Guard readfirstlane Broadcast (Debug Safety)

The broadcast optimization assumes that the entire wave resides on the same page, which in turn rests on the potentially flawed kVTileCrossesPages check.
Recommendation: Add a debug guard.

#ifdef CK_TILE_DEBUG_PAGE_CHECK
const index_t my_page_id =
    (global_seq_offset + thread_coord_start + kLoopStart) >> kLog2PageSize;
const index_t lane0_page_id =
    __builtin_amdgcn_readfirstlane(my_page_id);

if(my_page_id != lane0_page_id) {
    // Fallback to per-lane calculation if assumption fails
    // ... (Execute safe path)
    // Or trigger assert/trap
}
#endif

4. Harden Async Barriers

The current barrier condition is too specific to LdsSeq patterns.
Recommendation: Use a conservative barrier to prevent any K/V hazard in the async pipeline.

if constexpr(k1_loops >= 2)
{
    // Conservatively ensure no LDS hazard between K and V prefetch stages
    // Removing dependency on specific LdsSeq indices
    __builtin_amdgcn_s_barrier();
}

These patches aim to ensure correctness and determinism over aggressive micro-optimizations that compromise safety.


Update: Additional Critical Findings in Policy Implementation

Upon further inspection of block_fmha_pipeline_qx_ks_vs_custom_policy.hpp and ..._default_policy.hpp, I have identified four more specific implementation flaws that confirm the structural risks described above.

4. Confirmed Unsafe Memory Reuse via max(K,V)

The function GetSingleSmemElementSpaceSize uses max(SingleKSize, SingleVSize), confirming that K and V share the same physical memory slots.

  • The Hazard: In the async pipeline, prefetching V into a slot while GEMM0 is still reading K from that same slot creates a race condition. The kernel relies on loose barriers (sched_barrier) rather than rigorous wait-states to separate these phases.
  • Action: Verify that pipeline depth strictly prevents overlap, or disable reuse for safety.

5. Brittle LdsBufferSequence Specialization

The pipeline relies on manual template specializations for LdsBufferSequence to avoid memory hazards.

  • The Risk: Any configuration (e.g., NumPrefetch=4) without a matching specialization will fall back to a default sequence that may violate the "overlap avoidance rule", leading to silent data corruption.
  • Recommendation: Replace manual specialization with constexpr logic that mathematically guarantees non-overlapping indices.

5. Implicit Integer Division in Distribution

In MakeVDramTileDistribution, KOuterIter is calculated as kKPerBlock / KPerIter.

  • The Flaw: There is no assertion ensuring kKPerBlock is a multiple of KPerIter. Non-divisible configurations will silently truncate iteration counts.
  • Recommendation: Add static_assert(kKPerBlock % KPerIter == 0).

6. Logic Error (Typo) in NumPrefetchV Assignment

CRITICAL: In the policy definition, NumPrefetchV is incorrectly assigned NumPrefetchK_ instead of NumPrefetchV_.

// Current Code:
static constexpr index_t NumPrefetchV = NumPrefetchK_; 
  • Impact: Users cannot independently control V prefetch depth. If NumPrefetchV was intended to be smaller to save LDS, this logic forces it to match K, causing memory waste or pipeline stalls.
  • Recommendation: Correct the assignment to NumPrefetchV_.

Best regards.

Operating System

Ubuntu 22.04 LTS

CPU

AMD Ryzen 9 7950X

GPU

AMD Radeon RX 7900 XT

Other

Issue found via Static Code Analysis (Logic verification), unrelated to specific hardware.

ROCm Version

ROCm 6.0.0

ROCm Component

Composable Kernel (ck_tile) / FMHA

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response
