It looks like the loop counter that should be updated inside the `while` loop is not carried through the loop correctly, and MLIR verification fails with the following error:
Traceback (most recent call last):
File "/jruan/ws/aiter/op_tests/test_mla_persistent.py", line 996, in <module>
ret = test_mla(
File "/jruan/ws/aiter/aiter/test_common.py", line 128, in wrapper
ret = func(*args, **kwargs)
File "/jruan/ws/aiter/op_tests/test_mla_persistent.py", line 833, in test_mla
err, us_asm_decode = test_absorb_decode_fp8()
File "/jruan/ws/aiter/op_tests/test_mla_persistent.py", line 708, in test_absorb_decode_fp8
attn_logits, attn_lse = aiter.mla.mla_decode_fwd(
File "/jruan/ws/aiter/aiter/mla.py", line 379, in mla_decode_fwd
flydsl_attn_reduce_v1(
File "/jruan/ws/aiter/aiter/ops/flydsl/attn_reduce.py", line 155, in flydsl_attn_reduce_v1
launch_attn_reduce_ps(
File "/jruan/ws/FlyDSL/python/flydsl/compiler/jit_function.py", line 544, in __call__
compiled_module = MlirCompiler.compile(module, chip=chip, func_name=self.func.__name__)
File "/jruan/ws/FlyDSL/python/flydsl/compiler/jit_function.py", line 312, in compile
module.operation.verify()
flydsl._mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: unknown: 'scf.while' op expects the 'after' region to terminate with 'scf.yield'
note: unknown: see current operation:
%69:2 = "scf.while"(%68, %36) ({
^bb0(%arg19: i32, %arg20: i32):
%213 = "arith.cmpi"(%68, %36) <{predicate = 2 : i64}> : (i32, i32) -> i1
"scf.condition"(%213, %arg19, %arg20) : (i1, i32, i32) -> ()
}, {
^bb0(%arg11: i32, %arg12: i32):
"gpu.barrier"() : () -> ()
%70 = "arith.constant"() <{value = 128 : i32}> : () -> i32
%71 = "arith.remsi"(%68, %70) : (i32, i32) -> i32
%72 = "arith.constant"() <{value = 128 : i32}> : () -> i32
%73 = "arith.floordivsi"(%68, %72) : (i32, i32) -> i32
%74 = "arith.constant"() <{value = 1 : i32}> : () -> i32
%75 = "arith.remsi"(%73, %74) : (i32, i32) -> i32
%76 = "arith.constant"() <{value = 128 : i32}> : () -> i32
%77 = "arith.floordivsi"(%68, %76) : (i32, i32) -> i32
%78 = "arith.constant"() <{value = 1 : i32}> : () -> i32
%79 = "arith.floordivsi"(%77, %78) : (i32, i32) -> i32
%80 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%81 = "arith.muli"(%79, %80) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%82 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%83 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%84 = "rocdl.raw.ptr.buffer.load"(%4, %81, %82, %83) : (!llvm.ptr<8>, i32, i32, i32) -> vector<2xi32>
%85 = "vector.extract"(%84) <{static_position = array<i64: 0>}> : (vector<2xi32>) -> i32
%86 = "vector.extract"(%84) <{static_position = array<i64: 1>}> : (vector<2xi32>) -> i32
%87 = "arith.cmpi"(%85, %41) <{predicate = 1 : i64}> : (i32, i32) -> i1
%88 = "arith.subi"(%86, %85) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%89 = "arith.constant"() <{value = 1 : i32}> : () -> i32
%90 = "arith.cmpi"(%88, %89) <{predicate = 4 : i64}> : (i32, i32) -> i1
"scf.if"(%90) ({
%94 = "memref.get_global"() <{name = @smem_storage}> : () -> memref<8192xi8, #gpu.address_space<workgroup>>
%95 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
%96 = "arith.subi"(%86, %85) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%97 = "arith.index_cast"(%96) : (i32) -> index
%98 = "arith.index_cast"(%85) : (i32) -> index
%99 = "arith.constant"() <{value = 128 : index}> : () -> index
"scf.for"(%95, %97, %99) ({
^bb0(%arg18: index):
%203 = "arith.cmpi"(%arg18, %97) <{predicate = 6 : i64}> : (index, index) -> i1
"scf.if"(%203) ({
%204 = "arith.addi"(%98, %arg18) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
%205 = "arith.index_cast"(%204) : (index) -> i32
%206 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%207 = "arith.muli"(%205, %206) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%208 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%209 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%210 = "rocdl.raw.ptr.buffer.load"(%14, %207, %208, %209) : (!llvm.ptr<8>, i32, i32, i32) -> i32
%211 = "arith.constant"() <{value = 0 : index}> : () -> index
%212 = "memref.view"(%94, %211) : (memref<8192xi8, #gpu.address_space<workgroup>>, index) -> memref<2048xi32, #gpu.address_space<workgroup>>
"memref.store"(%210, %212, %arg18) : (i32, memref<2048xi32, #gpu.address_space<workgroup>>, index) -> ()
"scf.yield"() : () -> ()
}, {
}) : (i1) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
"gpu.barrier"() : () -> ()
%100 = "arith.constant"() <{value = 0 : index}> : () -> index
%101 = "arith.constant"() <{value = 0 : index}> : () -> index
%102 = "memref.view"(%94, %101) : (memref<8192xi8, #gpu.address_space<workgroup>>, index) -> memref<2048xi32, #gpu.address_space<workgroup>>
%103 = "memref.load"(%102, %100) : (memref<2048xi32, #gpu.address_space<workgroup>>, index) -> i32
%104 = "arith.constant"() <{value = 1 : index}> : () -> index
%105 = "memref.load"(%102, %104) : (memref<2048xi32, #gpu.address_space<workgroup>>, index) -> i32
%106 = "arith.constant"() <{value = 2 : i32}> : () -> i32
%107 = "arith.muli"(%79, %106) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%108 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%109 = "arith.muli"(%107, %108) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%110 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%111 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%112 = "rocdl.raw.ptr.buffer.load"(%9, %109, %110, %111) : (!llvm.ptr<8>, i32, i32, i32) -> vector<2xi32>
%113 = "vector.extract"(%112) <{static_position = array<i64: 0>}> : (vector<2xi32>) -> i32
%114 = "vector.extract"(%112) <{static_position = array<i64: 1>}> : (vector<2xi32>) -> i32
%115 = "arith.index_cast"(%95) : (index) -> i32
%116 = "arith.addi"(%113, %75) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%117 = "arith.index_cast"(%116) : (i32) -> index
%118 = "arith.index_cast"(%114) : (i32) -> index
%119 = "arith.constant"() <{value = 1 : index}> : () -> index
%120 = "arith.constant"() <{value = 1 : index}> : () -> index
"scf.for"(%117, %118, %119) ({
^bb0(%arg13: index):
%121 = "arith.index_cast"(%arg13) : (index) -> i32
%122 = "arith.subi"(%121, %113) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%123 = "arith.addi"(%103, %122) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%124 = "arith.constant"() <{value = 128 : i32}> : () -> i32
%125 = "arith.muli"(%123, %124) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%126 = "arith.addi"(%125, %71) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%127 = "arith.constant"() <{value = 512 : i32}> : () -> i32
%128 = "arith.muli"(%125, %127) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%129 = "arith.constant"() <{value = 512 : i32}> : () -> i32
%130 = "arith.muli"(%71, %129) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%131 = "arith.addi"(%128, %130) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%132 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%133 = "arith.muli"(%115, %132) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%134 = "arith.addi"(%131, %133) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%135 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%136 = "arith.muli"(%134, %135) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%137 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%138 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%139 = "rocdl.raw.ptr.buffer.load"(%34, %136, %137, %138) : (!llvm.ptr<8>, i32, i32, i32) -> vector<4xf32>
%140 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%141 = "arith.muli"(%126, %140) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%142 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%143 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%144 = "rocdl.raw.ptr.buffer.load"(%29, %141, %142, %143) : (!llvm.ptr<8>, i32, i32, i32) -> f32
%145 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
%146 = "arith.index_cast"(%85) : (i32) -> index
%147 = "arith.index_cast"(%86) : (i32) -> index
%148 = "arith.addi"(%146, %120) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
%149:3 = "scf.for"(%148, %147, %120, %139, %144, %145) ({
^bb0(%arg14: index, %arg15: vector<4xf32>, %arg16: f32, %arg17: f32):
%165 = "arith.subi"(%arg14, %146) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
%166 = "arith.constant"() <{value = 0 : index}> : () -> index
%167 = "memref.view"(%94, %166) : (memref<8192xi8, #gpu.address_space<workgroup>>, index) -> memref<2048xi32, #gpu.address_space<workgroup>>
%168 = "memref.load"(%167, %165) : (memref<2048xi32, #gpu.address_space<workgroup>>, index) -> i32
%169 = "arith.addi"(%168, %122) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%170 = "arith.constant"() <{value = 128 : i32}> : () -> i32
%171 = "arith.muli"(%169, %170) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%172 = "arith.addi"(%171, %71) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%173 = "arith.constant"() <{value = 512 : i32}> : () -> i32
%174 = "arith.muli"(%171, %173) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%175 = "arith.constant"() <{value = 512 : i32}> : () -> i32
%176 = "arith.muli"(%71, %175) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%177 = "arith.addi"(%174, %176) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%178 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%179 = "arith.muli"(%115, %178) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%180 = "arith.addi"(%177, %179) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%181 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%182 = "arith.muli"(%180, %181) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%183 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%184 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%185 = "rocdl.raw.ptr.buffer.load"(%34, %182, %183, %184) : (!llvm.ptr<8>, i32, i32, i32) -> vector<4xf32>
%186 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%187 = "arith.muli"(%172, %186) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%188 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%189 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%190 = "rocdl.raw.ptr.buffer.load"(%29, %187, %188, %189) : (!llvm.ptr<8>, i32, i32, i32) -> f32
%191 = "arith.maximumf"(%arg16, %190) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%192 = "arith.subf"(%arg16, %191) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%193 = "math.exp"(%192) <{fastmath = #arith.fastmath<none>}> : (f32) -> f32
%194 = "arith.subf"(%190, %191) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%195 = "math.exp"(%194) <{fastmath = #arith.fastmath<none>}> : (f32) -> f32
%196 = "vector.broadcast"(%193) : (f32) -> vector<4xf32>
%197 = "arith.mulf"(%196, %arg15) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%198 = "vector.broadcast"(%195) : (f32) -> vector<4xf32>
%199 = "arith.mulf"(%198, %185) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%200 = "arith.addf"(%197, %199) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%201 = "arith.mulf"(%arg17, %193) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%202 = "arith.addf"(%201, %195) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
"scf.yield"(%200, %191, %202) : (vector<4xf32>, f32, f32) -> ()
}) : (index, index, index, vector<4xf32>, f32, f32) -> (vector<4xf32>, f32, f32)
%150 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
%151 = "arith.divf"(%150, %149#2) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%152 = "vector.broadcast"(%151) : (f32) -> vector<4xf32>
%153 = "arith.mulf"(%152, %149#0) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
%154 = "arith.muli"(%121, %arg7) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%155 = "arith.muli"(%71, %arg8) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%156 = "arith.addi"(%154, %155) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%157 = "arith.constant"() <{value = 4 : i32}> : () -> i32
%158 = "arith.muli"(%115, %157) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%159 = "arith.addi"(%156, %158) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%160 = "arith.truncf"(%153) : (vector<4xf32>) -> vector<4xbf16>
%161 = "arith.constant"() <{value = 2 : i32}> : () -> i32
%162 = "arith.muli"(%159, %161) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
%163 = "arith.constant"() <{value = 0 : i32}> : () -> i32
%164 = "arith.constant"() <{value = 0 : i32}> : () -> i32
"rocdl.raw.ptr.buffer.store"(%160, %24, %162, %163, %164) : (vector<4xbf16>, !llvm.ptr<8>, i32, i32, i32) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
"scf.yield"() : () -> ()
}, {
}) : (i1) -> ()
%91 = "gpu.grid_dim"() <{dimension = #gpu<dim x>}> : () -> index
%92 = "arith.index_cast"(%91) : (index) -> i32
%93 = "arith.addi"(%68, %92) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
}) : (i32, i32) -> (i32, i32)
note: "/jruan/ws/aiter/aiter/ops/flydsl/kernels/attn_reduce.py":360:0: terminator here
Problem Description
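It looks like the loop counter that should be updated inside the `while` loop is not carried through the loop correctly, and MLIR verification fails with the error shown above. There are two problems in the generated IR:
- The `scf.while` "before" region compares the loop's initial operands (`%68`, `%36`) instead of its block arguments (`%arg19`, `%arg20`).
- The "after" region computes the next counter value (`%93 = %68 + grid_dim`) but does not end with an `scf.yield`.

As a result, the updated counter is never passed back to the next iteration.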
While loop code:
https://github.com/ROCm/aiter/blob/11d0b63e1412c6f771d0d9d0225cdc07fb6ecd16/aiter/ops/flydsl/kernels/attn_reduce.py#L439
Operating System
all
CPU
all
GPU
all
ROCm Version
all
ROCm Component
No response
Steps to Reproduce
jruan/fdsl_issue_211
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response