Skip to content

[MI450][Kernel] add Deepseek MHA bf16 kernel verified on MI450#393

Open
jli-melchior wants to merge 3 commits intomainfrom
jli/mha-gfx1250
Open

[MI450][Kernel] add Deepseek MHA bf16 kernel verified on MI450#393
jli-melchior wants to merge 3 commits intomainfrom
jli/mha-gfx1250

Conversation

@jli-melchior
Copy link
Copy Markdown
Contributor

Deepseek MHA bf16 kernel verified on MI450

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

…de style

Refactor machine-generated MHA kernel (1850→1633 lines) to align with
standard gfx1250 kernel conventions:
- Consolidate 20+ inline imports to top-level with consistent aliases
- Replace hand-rolled _SCFForCtx/_scf_yield_/_to_index with built-in
  range(init=...) + yield syntax
- Replace 128-line hand-unrolled extract+add with _tree_reduce calls
- Remove redundant _idx2crd wrapper (fx.idx2crd handles scalar ir.Value)
- Clean up duplicate imports in _shuffle_* helpers
- Reorganize constants with descriptive names

Verified: test passes with identical max_err=0.000468
Replace low-level arith.* function calls with ArithValue method syntax:
- arith.bitcast(T.xxx, val) → val.bitcast(T.xxx) (32 sites)
- arith.truncf(T.xxx, val) → val.truncf(T.xxx) (4 sites)
- arith.extsi(T.xxx, val) → val.extsi(T.xxx) (2 sites)
- arith.select(cond, a, b) → cond.select(a, b) (3 sites)
- arith.index_cast(T.xxx, val) → val.index_cast(T.xxx) (10 sites)

Retained explicit arith.index_cast in helper functions where values
may be raw ir.Value from MLIR dialect ops (not ArithValue).

Verified: test passes with identical max_err=0.000468
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant