Draft
Changes from all commits
270 commits
537d19c
Added packed instances for bf16xi4xbf16
ancahamuraru May 15, 2025
e97a7f2
Added padding instances for f8xf16xf16
ancahamuraru May 15, 2025
bbda71f
Added padding instances for f16xf8xf16, f16xi4xf16
ancahamuraru May 16, 2025
a0a2bf2
Fixed typos for bf16xbf16xbf16 padding instances
ancahamuraru May 19, 2025
7975c9c
Fixed typos for padded instances
ancahamuraru May 19, 2025
dc26ee3
Added tests for fp16, KM_KN and KM_NK
ancahamuraru May 20, 2025
a08ca63
Padding not supported when BDataType is pk_i4_t. Added fix for co…
ancahamuraru May 20, 2025
0a5e6d4
Fixed typos
ancahamuraru May 20, 2025
b350bd2
Updated the set of tests for FP16
ancahamuraru May 20, 2025
ae21582
Updated the set of tests for FP16
ancahamuraru May 20, 2025
185ea0f
Fix typo
ancahamuraru May 20, 2025
b1a9a27
Merge branch '33-wip' of projects.streamhpc.com:amd/ai/composable_ker…
ancahamuraru May 20, 2025
15bfa00
Moved f16xi4 test under the correct data layout group
ancahamuraru May 20, 2025
621012c
example for gemm_universal_bf16
ApoorvaKalyani May 7, 2025
b35a195
Adding examples for gemm_wmma instances
ApoorvaKalyani May 7, 2025
8f8e631
Added the missing parameters
ApoorvaKalyani May 7, 2025
840b79d
Fixed review comments and added executable to CMakeLists
ApoorvaKalyani May 8, 2025
4b5a9ac
Fixing clang format
ApoorvaKalyani May 8, 2025
9cc5702
Fixing build errors
ApoorvaKalyani May 8, 2025
b0aa933
Fixed compilation failure.
ApoorvaKalyani May 12, 2025
c016164
Modified some code as per gemm_universal_examples
ApoorvaKalyani May 13, 2025
9d8f1e4
Fixed the gemm specialization error
ApoorvaKalyani May 13, 2025
501d957
Fixed the build errors.
ApoorvaKalyani May 15, 2025
cc818b4
Fix strides of a/b_thread_desc
ex-rzr Apr 28, 2025
af9e9ed
Load in M/NRepeat dims with thread copy's slice instead of a loop
ex-rzr Apr 28, 2025
ede7126
Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation
ex-rzr Apr 28, 2025
c414097
Implement Intrawave and Interwave variants of pipeline v1
ex-rzr Apr 30, 2025
c94c3b4
Add instances for Interwave and Intrawave v1
ex-rzr May 16, 2025
04d3fc7
Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0
ex-rzr May 16, 2025
17bc0fa
Remove instances that are too slow (mostly because of register spilling)
ex-rzr May 19, 2025
342bb57
Add a workaround for fp8/bf8->f32 packed conversion issue
ex-rzr May 20, 2025
5082a9c
Add instances for Interwave and Intrawave v1
ex-rzr May 20, 2025
c7d39a0
Enable profiling of mixed precision with f8 and int4 on WMMA
ex-rzr May 20, 2025
8b5d340
Fix segfault in profiler when B is pk_i4_t
ex-rzr May 21, 2025
b1f50b5
Remove instances that are too slow (mostly because of register spilling)
ex-rzr May 21, 2025
02bf56a
Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations
ex-rzr May 21, 2025
dd7ac95
Add test case for bf16_i4
ex-rzr May 21, 2025
eac7d35
Add missing Regular tests
ex-rzr May 21, 2025
05ad214
Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS
ex-rzr May 22, 2025
83b1419
Fix a bug that fp16_i4 validation passes only with PermuteB
ex-rzr May 22, 2025
9e70603
Use PermuteB with f16_i4 in most instances (as xdl)
ex-rzr May 22, 2025
c143bf3
Fix cache flushing for pk_i4
ex-rzr May 22, 2025
668914c
Add mixed precision examples
ex-rzr May 22, 2025
2679c0a
Disable all tests and instances with f8 on gfx11
ex-rzr May 23, 2025
a6ea604
Add FP16 KM_NK and KM_KN test suites for XDL
ex-rzr May 23, 2025
da5f962
Support multiple D in GridwiseGemm_wmma_cshuffle_v3
ex-rzr May 29, 2025
99fc05e
Use ThreadGroupTensorSliceTransfer_v7r3
ex-rzr May 29, 2025
a038ba3
Clone device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
ex-rzr May 29, 2025
5151206
Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for…
ex-rzr May 29, 2025
f13b913
Implement DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
db51d8a
Make gemm_add_add_wmma work with DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
22935c8
Prepare gemm_add tests for adding wmma
ex-rzr Jun 2, 2025
25f7204
Add gemm_add_fastgelu instances and test
ex-rzr Jun 2, 2025
959defb
Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with…
ex-rzr Jun 2, 2025
b8e45c7
removed unnecessary ck parts from compilation
May 30, 2025
538fa87
initial gemm_add_multiply instance implementations
May 30, 2025
8727762
fixed profiler help message for gemm_add_multiply
May 30, 2025
63513c3
improved multiply_add profiler layout help
May 30, 2025
07f75d9
fixed template arguments for test instances
Jun 2, 2025
75550ff
added test for gemm_add_multiply
Jun 3, 2025
ed047d0
Support multiple D in GridwiseGemm_wmma_cshuffle_v3
ex-rzr May 29, 2025
deebe1e
Use ThreadGroupTensorSliceTransfer_v7r3
ex-rzr May 29, 2025
7dff5fe
Clone device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support
ex-rzr May 29, 2025
89ac60d
Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for…
ex-rzr May 29, 2025
137efa7
Implement DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
e36a176
Make gemm_add_add_wmma work with DeviceGemmMultipleD_Wmma_CShuffleV3
ex-rzr May 29, 2025
bcf93e2
Prepare gemm_add tests for adding wmma
ex-rzr Jun 2, 2025
381c02d
Add gemm_add_fastgelu instances and test
ex-rzr Jun 2, 2025
9912e5f
Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with…
ex-rzr Jun 2, 2025
881bc3f
Merge branch '52-implement-multipled-in-gemm-universal' into 'feature…
ex-rzr Jun 4, 2025
4e07085
switched to splitK interface
Jun 4, 2025
8658ca6
log print added to splitk benchmarks
Jun 5, 2025
a902c57
revert main cmake comments
Jun 5, 2025
32e78b6
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jun 5, 2025
0228eca
newline change reverted
Jun 6, 2025
ea9805b
added add_fastgelu instances
Jun 10, 2025
aeca8ef
revert unintended change in xdl add_fastgelu
Jun 11, 2025
4c8ea95
created gemm_add_add_fastgelu instances
Jun 11, 2025
264e1b2
created fastgelu instances
Jun 11, 2025
b4d3e41
added tests for all splitk fastgelus
Jun 12, 2025
0696f99
Added tests.
ApoorvaKalyani Jun 13, 2025
a529e3e
multiply_add instances created
Jun 13, 2025
27d86a3
updates to add_multiply splitk instances
Jun 13, 2025
61b6e9a
splitk xdl test fixes
Jun 13, 2025
ac60286
added wmma multiply_multiply instances
Jun 17, 2025
7424b4a
fixed ONLY_XDL_AND_WMMA_KERNELS tag
Jun 17, 2025
30d65b9
Added gemm_add examples for wmma v1 and v3
ApoorvaKalyani Jun 18, 2025
90c9b09
Merge branch '61-add-examples-for-bf16-and-fp16-instances-of-gemm_add…
ApoorvaKalyani Jun 18, 2025
cd0172b
fixed / worked around i8 instances
Jun 18, 2025
055bc02
Merge branch '10-implement-device_gemm_add_fastgelu-for-rdna4' into '…
Jun 18, 2025
40ce862
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 6…
Jun 18, 2025
c2077ca
Modified the v3 code to add one fp16 xdl instance.
ApoorvaKalyani Jun 18, 2025
57c3fd9
added bf16 xdl instance.
ApoorvaKalyani Jun 18, 2025
b42b6b6
adding gemm_add wmma_cshuffle and other support
ApoorvaKalyani May 26, 2025
113ea09
add instances into CMakeLists
ApoorvaKalyani May 26, 2025
b129e73
This is work in progress, edited the template parameters in order to …
ApoorvaKalyani May 26, 2025
455275d
temp work saved, changed the BDataType to f16 or bf16 since wmma curr…
ApoorvaKalyani May 26, 2025
1fda499
added datatype and use clang-format-12
ApoorvaKalyani May 26, 2025
1519eaa
Fixing build errors
ApoorvaKalyani May 28, 2025
32b9500
Added instances for v3
ApoorvaKalyani Jun 11, 2025
7da9f64
Adding instances and executables
ApoorvaKalyani Jun 11, 2025
0cce81c
Code update of template parameters modified.
ApoorvaKalyani Jun 12, 2025
6df313f
Renamed file.
ApoorvaKalyani Jun 12, 2025
06d44f1
Added tests.
ApoorvaKalyani Jun 13, 2025
10d648a
resolved test errors.
ApoorvaKalyani Jun 13, 2025
ef781db
Fixing build errors
ApoorvaKalyani Jun 13, 2025
bd49ec0
Updated comments
ApoorvaKalyani Jun 13, 2025
3301ef5
removed the changes as per the MR review comment.
ApoorvaKalyani Jun 19, 2025
38d0027
Updated tests.
ApoorvaKalyani Jun 19, 2025
5e45427
fp8 instances - not tested
Jun 19, 2025
c8b3f3d
Restored the Cmake file that was reverted by mistake during rebase.
ApoorvaKalyani Jun 19, 2025
a8dec7a
fixed wmma_op test
Jun 19, 2025
78c2ee2
Updated comments.
ApoorvaKalyani Jun 19, 2025
ed5ac21
Updated the template parameter description
ApoorvaKalyani Jun 19, 2025
1c01ff6
fixed rdna4 instances
Jun 23, 2025
fb4c1b5
fixed backward compatibility on gfx11
Jun 23, 2025
d7b4d51
cleanups
Jun 24, 2025
94f543c
fix ckProfiler
Jun 24, 2025
8b694c3
one more cmake fix
Jun 24, 2025
3c3136b
added fp8 instances
Jun 24, 2025
71d65d4
Updated tests to add BF16 instances as per review comment
ApoorvaKalyani Jun 25, 2025
ee8c278
Added include file and cleaned up(as per review comment)
ApoorvaKalyani Jun 25, 2025
7840db4
Updated and optimized the example code for all types.
ApoorvaKalyani Jun 25, 2025
3037858
Fixed clang format
ApoorvaKalyani Jun 25, 2025
686df33
Resolve "Implement `device_gemm_bilinear` for RDNA4"
Jun 26, 2025
4f19101
Merge branch '63-implement-device_gemm_bilinear-for-rdna4' into 'feat…
Jun 26, 2025
6ba1dc6
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jun 30, 2025
eaa0452
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 6…
Jun 30, 2025
0794833
test generalization to handle FP16 shuffle better
Jun 30, 2025
bb7b307
added missing changes
Jun 30, 2025
35aab35
Added bf16 wmma instance for add_relu
ApoorvaKalyani Jun 19, 2025
6f89183
Added f16 wmma instance and corrected bf16 instance errors.
ApoorvaKalyani Jun 23, 2025
cdaff7f
Added instances to Cmake
ApoorvaKalyani Jun 24, 2025
6a116fa
Modified the template parameters to make the instances work.
ApoorvaKalyani Jul 1, 2025
bb7f665
Fixed typo in profiler
ApoorvaKalyani Jul 1, 2025
f5843dd
Added v3 instances for gemm_add_relu
ApoorvaKalyani Jul 1, 2025
ff31873
addressed code review comments
Jul 1, 2025
6ec0ad2
Added test for gemm_add_relu wmma instance
ApoorvaKalyani Jul 1, 2025
feca919
Cleaned up the code.
ApoorvaKalyani Jul 1, 2025
ba9c637
Added examples for gemm_add_relu
ApoorvaKalyani Jul 2, 2025
5c491e7
Fixing typo to resolve build errors.
ApoorvaKalyani Jul 2, 2025
8a5bb25
Applied fixes for the precision loss.
ApoorvaKalyani Jul 7, 2025
0551b84
fix bilinear test after merge
Jul 8, 2025
86ca6b8
Removed the old wmma instances.
ApoorvaKalyani Jul 8, 2025
9b64da2
Added wrapper and renamed the wmma_v3 instances
ApoorvaKalyani Jul 8, 2025
669befb
Updated copyrights and added wrappers.
ApoorvaKalyani Jul 8, 2025
bdfdb0c
Fixes applied according to review comments
ApoorvaKalyani Jul 8, 2025
d3a26e5
Apply 1 suggestion(s) to 1 file(s)
ApoorvaKalyani Jul 8, 2025
84b0b32
Removed the old wmma instances.
ApoorvaKalyani Jul 8, 2025
516d1f5
Updated wrapper for the v3 instances
ApoorvaKalyani Jul 8, 2025
e59d281
removed the old wmma examples
ApoorvaKalyani Jul 8, 2025
566e472
Renamed the v3 instances
ApoorvaKalyani Jul 8, 2025
9655010
Deleted the gtest file added by mistake.
ApoorvaKalyani Jul 8, 2025
536f866
Updated the profiler with wrapper
ApoorvaKalyani Jul 8, 2025
13efcc6
Fixed test errors.
ApoorvaKalyani Jul 8, 2025
55299c9
Fixed the review comments
ApoorvaKalyani Jul 9, 2025
3212507
Fixed the if condition MACROS.
ApoorvaKalyani Jul 9, 2025
21cb985
REVERTED THE PROFILER CHANGES
ApoorvaKalyani Jul 9, 2025
e1374ea
Revert "REVERTED THE PROFILER CHANGES"
ApoorvaKalyani Jul 9, 2025
9e3d87e
Revert "Fixed test errors."
ApoorvaKalyani Jul 9, 2025
ea133bf
Revert "Updated thge profiler with wrapper"
ApoorvaKalyani Jul 9, 2025
76f4bb0
Added missing wrapper instances
ApoorvaKalyani Jul 9, 2025
2738ca5
Updated copyrights.
ApoorvaKalyani Jul 9, 2025
e6ea4aa
Fixed typo.
ApoorvaKalyani Jul 9, 2025
8d64718
Fixed copyrights.
ApoorvaKalyani Jul 9, 2025
8e91755
Updated copyrights.
ApoorvaKalyani Jul 9, 2025
aea158f
updated copyrights.
ApoorvaKalyani Jul 9, 2025
a7993ab
comments on the atomics workaround
Jul 10, 2025
0dc871a
Merge branch '64-implement-device_gemm_multiply_multiply_instance-for…
Jul 11, 2025
41d4500
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jul 11, 2025
9c1314d
Merge branch '51-create-bf16-and-f16-instances-for-gemm_add-cshuffle_…
ApoorvaKalyani Jul 14, 2025
036799d
Merge branch '61-add-examples-for-bf16-and-fp16-instances-of-gemm_add…
ApoorvaKalyani Jul 14, 2025
27c0f95
Merge branch '79-add-instances-and-examples-for-device_gemm_add_relu'…
ApoorvaKalyani Jul 14, 2025
e2a75d6
Merge remote-tracking branch 'origin/feature/multiple-d-gemms' into 8…
Jul 14, 2025
161fe6c
fixed cmake comment
Jul 14, 2025
1de5d98
Merge branch '8-implement-device_gemm_add_multiply-for-rdna4' into 'f…
Jul 14, 2025
5dc21c5
Merge branch 'develop' into feature/multiple-d-gemms
EnricoDeg Jul 28, 2025
02cb1f2
Fix bug from merge
EnricoDeg Aug 4, 2025
ec38280
Merge remote-tracking branch 'origin/develop' into 90-prepare-an-upst…
krithalith Aug 6, 2025
c434378
clang-format-18
krithalith Aug 6, 2025
8f01112
Fix compilation error
EnricoDeg Aug 6, 2025
9ee5699
multi_abd wmma support:
EnricoDeg Jul 9, 2025
9b7cf3e
Fix bug in device print function
EnricoDeg Jul 23, 2025
ccf696a
Fix unused template parameter
EnricoDeg Jul 23, 2025
16920de
Add support for fwd conv in gridwise implementation. Identical to run…
krithalith Aug 20, 2025
43f99d8
Initial device implementation for grouped conv fwd multiABD wmma cshu…
krithalith Aug 20, 2025
4354cef
Make relevant profilers print the number of valid instances to aid te…
krithalith Aug 20, 2025
9089f2c
Add instances for all vanilla 2D and 3D flavors for f16 and bf16, onl…
krithalith Aug 20, 2025
f906c70
Reset output buffer after each run in profile_grouped_conv_fwd_impl().
krithalith Aug 24, 2025
b53c584
Disable sharding for the new instances for now, has tendency to lead …
krithalith Aug 24, 2025
6ad73cd
Add CTranspose optimization for NCHW cases just like in xdl cshuffle …
krithalith Aug 24, 2025
e325dab
Add instances for all 8-bit 3D vanilla grouped conv fwd types, includ…
krithalith Aug 26, 2025
ca7b312
Add int8 instances for 2D vanilla grouped conv fwd all layouts.
krithalith Aug 26, 2025
73521fe
Implement merged groups in device impl and add instances for merged g…
krithalith Aug 27, 2025
a06e276
Add merged groups instances for all 2D vanilla grouped conv fwd types…
krithalith Aug 27, 2025
68f9e73
Implement multi-AB support for grouped conv fwd and add example.
krithalith Aug 28, 2025
78635fd
Add 1D instances
krithalith Aug 28, 2025
382d6fe
Add D layout tests to IsSupportedArgument()
krithalith Aug 29, 2025
63f52e0
Add comp and mem instances for all vanilla 2D grouped conv fwd types.…
krithalith Aug 29, 2025
812b485
Add comp and mem instances for vanilla 3D grouped conv fwd. Skipped 2…
krithalith Aug 31, 2025
bc2c2fd
Add some more tests for vanilla grouped conv fwd
krithalith Sep 1, 2025
4822517
Add 2D bias clamp instances and tests
krithalith Sep 1, 2025
0b8de9a
Add 3D bias clamp instances and tests
krithalith Sep 1, 2025
9416c82
Add 2D and 3D clamp instances and tests
krithalith Sep 1, 2025
bcf9279
Unify problem sizes across vanilla and clamp flavor tests
krithalith Sep 2, 2025
52c42d5
Clean up device implementation: remove old todos, remove unnecessary …
krithalith Sep 2, 2025
b8d4b01
Implement rotating memory and flush cache. Requires ad-hoc buffer siz…
krithalith Sep 4, 2025
e7314a1
Remove wmma fp8 and bf8 instances when not targeting gfx12
krithalith Sep 5, 2025
521970c
Add newer instances to DEVICE_INSTANCES so the main ckProfiler can build
krithalith Sep 5, 2025
b9986de
Remove old years for newly created files.
krithalith Sep 15, 2025
eea8476
No need to time kernels for now.
krithalith Sep 15, 2025
e3fccf0
Fixup comments
krithalith Sep 16, 2025
a28d102
Pass struct args to Gridwise Run() function by reference.
krithalith Sep 16, 2025
a8a5504
Don't use workspace memory in the case where A needs explicit transpo…
krithalith Sep 16, 2025
58e7321
Move calculation of rotating memory buffer sizes to Argument member f…
krithalith Sep 16, 2025
a6dbb39
After the convolution to gemm transformation, the resulting 2D tensor…
krithalith Sep 18, 2025
ae3e373
Unify xdl and wmma example code for grouped conv fwd scaleadd ab
krithalith Sep 18, 2025
b26f2c6
Go back to passing RCR 2D tensor layouts to gridwise gemm, and use CR…
krithalith Sep 18, 2025
1cc3a9e
Add wmma scaleadd ab instances to the device factory and add a comple…
krithalith Sep 19, 2025
fc61c5d
Add support for V3 pipeline (tested). To be able to support num_loop …
krithalith Sep 22, 2025
67a6757
Merge remote-tracking branch 'origin/develop' into 65-grouped-conv-fw…
krithalith Sep 23, 2025
a126f5c
Small post-merge fixup, everything seems to work.
krithalith Sep 24, 2025
238218b
Do not build or run Xdl operations with Wmma backend for now. Will be…
krithalith Sep 24, 2025
f26e00e
Extend scaleadd_ab instance lists
krithalith Sep 25, 2025
ee5225f
Extend merged groups instance lists, including adaptations of xdl "2x…
krithalith Sep 25, 2025
5cc80ca
Extend "comp" instance lists, including "2x" and "part2" instances. 2…
krithalith Sep 26, 2025
2bb627f
Extend "mem" instance lists.
krithalith Sep 26, 2025
8cd5e3f
Extend regular instance lists.
krithalith Sep 29, 2025
1b9bf99
Fixup comments and ignored kernel arg name
krithalith Sep 30, 2025
28706e6
Properly use the splitN offsets for D tensors in the gridwise Run() f…
krithalith Sep 30, 2025
d0f59a5
Make sure all strides in ComputePtrOffset are at least value initiali…
krithalith Oct 1, 2025
bd00884
Re-enable sharding for wmma cshufflev3 instances
krithalith Oct 1, 2025
3b0979a
Merge remote-tracking branch 'origin/develop' into 65-grouped-conv-fw…
krithalith Oct 2, 2025
c3d5da4
Post merge fix to vanilla test
krithalith Oct 2, 2025
1d14d83
Optionally allow num_k_loop <= PrefetchStages in gridwise CheckValidi…
krithalith Oct 3, 2025
8c80631
Merge remote-tracking branch 'origin/develop' into 65-grouped-conv-fw…
krithalith Oct 3, 2025
22fb5c5
Remove spurious ck_tile changes that were presumably introduced somew…
krithalith Oct 3, 2025
5cc470b
Merge remote-tracking branch 'origin/develop' into 65-grouped-conv-fw…
krithalith Nov 10, 2025
51a4ae4
Post-merge fixes. Make sure the new gridwise gemm wmma v3 common Run …
krithalith Nov 10, 2025
08e3e9e
Disable FP8 / BF8 testing on CDNA1/2, it doesn't work anymore and nee…
krithalith Nov 18, 2025
7521983
Re-enable old wmma instances
krithalith Oct 1, 2025
c52fbb9
Re-enable Linqun's Xdl Wmma instances
krithalith Oct 2, 2025
dbb2e39
Merge remote-tracking branch 'origin/develop' into 65-grouped-conv-fw…
krithalith Dec 15, 2025
291c6fe
Small post-merge fixes
krithalith Dec 15, 2025
6dd37ab
Fix copyright headers
krithalith Dec 15, 2025
e7170df
Merge branch 'develop' into streamhpc/grouped-conv-fwd-wmma
krithalith Dec 16, 2025
295e899
Remove commented code snippet in gridwise
krithalith Dec 16, 2025
294e14b
Limit the explicit cast added in threadwise_tensor_slice_transfer_v7r…
krithalith Dec 16, 2025
4df4747
Adding tuned instance list for grouped conv fwd (#3288)
wj-laskowski Dec 16, 2025
0b0aa06
Adding remaining flavors for grouped conv fwd
wj-laskowski Nov 25, 2025
9 changes: 8 additions & 1 deletion example/62_convnd_activ/convinvscale/CMakeLists.txt
@@ -5,4 +5,11 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_convnd_activ_xdl_convinvscale)
add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8)
endif()
endif()

# WMMA
if (GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_convnd_activ_wmma_convinvscale)
add_example_executable(example_convnd_fwd_wmma_convinvscale_fp8 convnd_fwd_wmma_convinvscale_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convinvscale example_convnd_fwd_wmma_convinvscale_fp8)
endif()
97 changes: 97 additions & 0 deletions example/62_convnd_activ/convinvscale/convnd_fwd_wmma_convinvscale_fp8.cpp
@@ -0,0 +1,97 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "convnd_fwd_convinvscale_common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"

using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvInvscale;

static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvInvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge

#include "run_convnd_fwd_convinvscale_example.inc"

int main(int argc, char* argv[])
{
if(!ck::is_gfx12_supported())
{
std::cout << "This kernel support gfx12 only" << std::endl;

return 0;
}
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}
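
A note on the configuration above: with GemmSpecialization::MNKPadding, the implicit-GEMM view of the forward convolution is padded up to multiples of MPerBlock, NPerBlock, and KPerBlock (64, 64, 32 here), so arbitrary problem sizes are accepted at the cost of some wasted work on the padded fringe. The standalone sketch below illustrates that arithmetic, assuming the usual implicit-GEMM mapping for conv fwd (GemmM = N*Ho*Wo, GemmN = K, GemmK = C*Y*X per group); the problem sizes are made up for illustration and none of this is CK API.

// Standalone sketch (not CK code): implicit-GEMM sizes for a 2D grouped
// conv fwd problem and the effect of GemmSpecialization::MNKPadding with
// the tile sizes used in the instance above.
#include <cstdint>
#include <iostream>

int main()
{
    // Illustrative per-group problem: batch N, output channels K, input
    // channels C, filter Y x X, output spatial Ho x Wo.
    const std::int64_t N = 4, K = 96, C = 24, Y = 3, X = 3, Ho = 28, Wo = 28;

    // Assumed implicit-GEMM mapping for forward convolution.
    const std::int64_t GemmM = N * Ho * Wo; // 3136
    const std::int64_t GemmN = K;           // 96
    const std::int64_t GemmK = C * Y * X;   // 216

    // Tile sizes from the template arguments above.
    const std::int64_t MPerBlock = 64, NPerBlock = 64, KPerBlock = 32;

    const auto pad_up = [](std::int64_t v, std::int64_t tile) { return (v + tile - 1) / tile * tile; };

    const std::int64_t Mp = pad_up(GemmM, MPerBlock); // 3136, already a multiple
    const std::int64_t Np = pad_up(GemmN, NPerBlock); // 96  -> 128
    const std::int64_t Kp = pad_up(GemmK, KPerBlock); // 216 -> 224

    // One workgroup per MPerBlock x NPerBlock tile of the padded output.
    std::cout << "padded GEMM: " << Mp << " x " << Np << " x " << Kp
              << ", output tiles: " << (Mp / MPerBlock) * (Np / NPerBlock) << '\n'; // 98 tiles
}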
16 changes: 16 additions & 0 deletions example/62_convnd_activ/convscale/CMakeLists.txt
@@ -15,3 +15,19 @@ if (NOT GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8)
endif()

# WMMA
if (GPU_TARGETS MATCHES "gfx12")
add_custom_target(example_convnd_activ_wmma_convscale)
add_example_executable(example_convnd_fwd_wmma_convscale_fp8 convnd_fwd_wmma_convscale_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8)

add_example_executable(example_convnd_fwd_wmma_convscale_bf8 convnd_fwd_wmma_convscale_bf8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8)

add_example_executable(example_convnd_fwd_wmma_convscale_fp8_bf8 convnd_fwd_wmma_convscale_fp8_bf8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_fp8_bf8)

add_example_executable(example_convnd_fwd_wmma_convscale_bf8_fp8 convnd_fwd_wmma_convscale_bf8_fp8.cpp)
add_example_dependencies(example_convnd_activ_wmma_convscale example_convnd_fwd_wmma_convscale_bf8_fp8)
endif()
97 changes: 97 additions & 0 deletions example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8.cpp
@@ -0,0 +1,97 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "convnd_fwd_convscale_common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"

using InDataType = ck::bf8_t;
using WeiDataType = ck::bf8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = InDataType;
using BComputeDataType = AComputeDataType;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;

static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge

#include "run_convnd_fwd_convscale_example.inc"

int main(int argc, char* argv[])
{
if(!ck::is_gfx12_supported())
{
std::cout << "This kernel support gfx12 only" << std::endl;

return 0;
}
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}
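
The ConvScale output functor used above is defined in convnd_fwd_convscale_common.hpp, which is not part of this diff. The sketch below only illustrates the common fp8 convention such epilogues follow, where per-tensor scale factors are applied to the float accumulator before narrowing to the 8-bit type; the struct, member names, and values are illustrative assumptions, not the actual functor.

// Illustrative sketch (not the actual ConvScale functor): the float GEMM
// accumulator is scaled by per-tensor factors and then converted to the
// 8-bit output type. Plain float stands in for ck::f8_t to keep the
// sketch self-contained and runnable.
#include <iostream>

struct ConvScaleSketch
{
    float scale_in;  // input-tensor scale (assumed name)
    float scale_wei; // weight-tensor scale (assumed name)
    float scale_out; // output-tensor scale (assumed name)

    float operator()(float acc) const
    {
        // In the real kernel the product would be converted to f8 here.
        return acc * scale_in * scale_wei * scale_out;
    }
};

int main()
{
    const ConvScaleSketch op{0.5f, 0.25f, 2.0f};
    std::cout << op(123.4f) << '\n'; // prints 30.85
}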
97 changes: 97 additions & 0 deletions example/62_convnd_activ/convscale/convnd_fwd_wmma_convscale_bf8_fp8.cpp
@@ -0,0 +1,97 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "convnd_fwd_convscale_common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_wmma_cshuffle_v3.hpp"

using InDataType = ck::bf8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::bf8_t;
using BComputeDataType = ck::f8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;

static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Wmma_CShuffle_V3<
NDimSpatial, // NDimSpatial
InLayout, // ALayout
WeiLayout, // BLayout
DsLayout, // DsLayout (empty tuple for ConvScale)
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
DsDataType, // DsDataType (empty tuple)
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
64, // BlockSize
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerWmma
16, // NPerWmma
4, // MRepeat
2, // NRepeat
S<4, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMRepeatPerShuffle
1, // CShuffleNRepeatPerShuffle
S<1, 16, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
1, // CDEBlockTransferScalarPerVector_NPerBlock
ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer
AComputeDataType, // AComputeDataType
BComputeDataType, // BComputeDataType
1>; // NumGroupsToMerge

#include "run_convnd_fwd_convscale_example.inc"

int main(int argc, char* argv[])
{
if(!ck::is_gfx12_supported())
{
std::cout << "This kernel support gfx12 only" << std::endl;

return 0;
}
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}
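
All three new examples share a single WMMA tile configuration, and its template arguments have to be mutually consistent: the copy thread clusters must fill the workgroup, and the wave decomposition of the block tile must match BlockSize. The check below restates that arithmetic as static assertions, assuming wave32 on gfx12 and the usual CK decomposition MPerBlock = MRepeat * MPerWmma * MWaves (likewise for N); it is an editorial sketch, not code from this PR.

// Editorial sanity check (not PR code) on the shared WMMA instance parameters.
int main()
{
    constexpr int BlockSize = 64, MPerBlock = 64, NPerBlock = 64, KPerBlock = 32;
    constexpr int AK1 = 8, BK1 = 8;
    constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 4, NRepeat = 2;

    // The A-copy cluster S<4, 16, 1> spans AK0 x M x AK1 with AK0 = KPerBlock / AK1 = 4,
    // and 4 * 16 * 1 threads must equal BlockSize (same for the B-copy cluster).
    static_assert(KPerBlock / AK1 * 16 * 1 == BlockSize, "A-copy cluster must fill the block");
    static_assert(KPerBlock / BK1 * 16 * 1 == BlockSize, "B-copy cluster must fill the block");

    // Assuming MPerBlock = MRepeat * MPerWmma * MWaves (and likewise for N):
    constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 1
    constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 2
    static_assert(MWaves * NWaves * 32 == BlockSize, "waves (wave32) must equal BlockSize");

    return 0;
}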