Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions custom_ops/setup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def get_gencode_flags(archs):
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
elif cc_val == 103:
arch_code = "103a"
flags += [
"-gencode",
f"arch=compute_{arch_code},code=sm_{arch_code}",
]
else:
flags += ["-gencode", f"arch=compute_{cc_val},code=sm_{cc_val}"]
return flags
Expand Down Expand Up @@ -476,9 +482,10 @@ def find_end_files(directory, end_str):
# of them instead of only the highest one.
has_sm90 = 90 in sm_versions
has_sm100 = 100 in sm_versions and nvcc_version >= 12.9
has_generic_fp8 = not has_sm90 and not has_sm100 # SM89 or other
has_sm103 = 103 in sm_versions and nvcc_version >= 13.0
has_generic_fp8 = not has_sm90 and not has_sm100 and not has_sm103 # SM89 or other

if has_sm90 or has_sm100:
if has_sm90 or has_sm100 or has_sm103:
nvcc_compile_args += [
"-O3",
"-DNDEBUG",
Expand All @@ -501,8 +508,8 @@ def find_end_files(directory, end_str):
"gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm_azp_sm90_int8.cu",
]

if has_sm100:
print("SM100 (Blackwell): Applying SM100 configurations.")
if has_sm100 or has_sm103:
print("SM100 / 103 (Blackwell): Applying SM100 / SM103 configurations.")
# Placeholder for SM100-specific kernel auto-generation scripts
# These might be needed if Blackwell has new FP8 hardware features
# not covered by existing generic CUTLASS templates or SM90 scripts.
Expand Down
Loading