Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docker/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ def get_gpu_type():

from test_txl import test_softmax
#test_softmax(size=16*1024)
test_softmax(M=32*1024, N=32*1024)
for i in [1, 4, 8, 16, 32]:
test_softmax(M=32*1024, N=i*1024)
164 changes: 164 additions & 0 deletions docker/draw/draw_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import numpy as np
import matplotlib.pyplot as plt

# Method names drawn in the top-row panels and in the bottom-row panels.
methods_up = ["FA3 (CUTLASS)", "Txl", "Triton", "TileLang", "ThunderKittens"]
methods_down = ["FA3 (CUTLASS)", "Txl", "Triton"]

# Bar colour for each method.
colors = {
    "FA3 (CUTLASS)": "#f1c40f",
    "Txl": "#e74c3c",
    "Triton": "#1abc9c",
    "TileLang": "#ff6fb3",
    "ThunderKittens": "#3498db",
}

# X axis: the context lengths shared by all four panels.
ctx = np.array([1024, 2048, 4096, 8192, 16384])
# Positions actually used for drawing: 0, 1, 2, ... (one slot per context
# length), so the groups are evenly spaced.
x = np.arange(ctx.size)

# ============ sample data, replace as needed ============
# One value per entry of `ctx`.

# FP16, causal = False
fp16_nc_fa3 = np.array([570, 600, 610, 630, 640])
fp16_nc_txl = np.array([484, 544, 578, 597, 608])
fp16_nc_triton = np.array([390, 460, 500, 520, 540])
fp16_nc_tile = np.array([447, 590, 610, 570, 600])
fp16_nc_tk = np.array([453, 597, 599, 610, 590])

# FP16, causal = True
fp16_c_fa3 = np.array([387, 512, 589, 613, 622])
fp16_c_txl = np.array([335, 429, 497, 539, 573])
fp16_c_triton = np.array([238, 376, 421, 473, 500])
fp16_c_tile = np.array([346, 436, 511, 478, 509])
fp16_c_tk = np.array([302, 421, 437, 457, 477])

# FP8, causal = False
fp8_nc_fa3 = np.array([587, 770, 850, 900, 980])
fp8_nc_txl = np.array([547, 627, 676, 703, 706])
fp8_nc_triton = np.array([520, 610, 690, 716, 723])

# FP8, causal = True
fp8_c_fa3 = np.array([379, 578, 738, 812, 864])
fp8_c_txl = np.array([343, 485, 591, 639, 664])
fp8_c_triton = np.array([333, 418, 558, 623, 679])
# ========================================================

# Panel name -> {method name -> values}; each inner mapping follows the
# order of the matching `methods_*` list.
data = {
    "FP16, causal=False": dict(zip(
        methods_up,
        [fp16_nc_fa3, fp16_nc_txl, fp16_nc_triton, fp16_nc_tile, fp16_nc_tk],
    )),
    "FP16, causal=True": dict(zip(
        methods_up,
        [fp16_c_fa3, fp16_c_txl, fp16_c_triton, fp16_c_tile, fp16_c_tk],
    )),
    "FP8, causal=False": dict(zip(
        methods_down,
        [fp8_nc_fa3, fp8_nc_txl, fp8_nc_triton],
    )),
    "FP8, causal=True": dict(zip(
        methods_down,
        [fp8_c_fa3, fp8_c_txl, fp8_c_triton],
    )),
}

# 2x2 grid of panels that share the x axis.
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, figsize=(12, 4))

# Width of a single bar, in units of the x slots (0, 1, 2, ...).
bar_width = 0.16

def plot_panel(ax, title, panel_data, ylim, yticks, methods):
    """Draw one grouped-bar panel on ``ax``.

    Reads the module-level ``x``, ``ctx``, ``bar_width`` and ``colors``.

    Args:
        ax: Axes object to draw on.
        title: Title set on the panel.
        panel_data: Mapping from method name to the bar heights for it.
        ylim: Pair unpacked into ``ax.set_ylim``.
        yticks: Tick positions passed to ``ax.set_yticks``.
        methods: Method names to draw, in left-to-right order inside
            each group of bars.
    """
    n_methods = len(methods)
    # Per-method horizontal shifts, symmetric around zero, so each group of
    # bars is centred on its x slot.
    shifts = (np.arange(n_methods) - (n_methods - 1) / 2) * bar_width

    for shift, name in zip(shifts, methods):
        heights = panel_data[name]
        # Only the FA3 (CUTLASS) bars are hatched.
        hatch_kwargs = {"hatch": "//"} if name == "FA3 (CUTLASS)" else {}
        ax.bar(
            x + shift,  # note: slot positions `x`, not the raw `ctx` values
            heights,
            bar_width,
            label=name,
            color=colors[name],
            **hatch_kwargs,
        )

    ax.set_title(title)
    ax.set_ylim(*ylim)
    ax.set_yticks(yticks)
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    # Ticks sit on the slots; only their labels show 1024, 2048, ...
    ax.set_xticks(x)
    ax.set_xticklabels(ctx)

# One entry per panel:
# (axes, title, key into `data`, y range, y ticks, methods).
# Top row: FP16. Bottom row: FP8.
panel_specs = (
    (axes[0, 0], "FP16, causal=false", "FP16, causal=False",
     (0, 800), [0, 200, 400, 600, 800], methods_up),
    (axes[0, 1], "FP16, causal=true", "FP16, causal=True",
     (0, 800), [0, 200, 400, 600, 800], methods_up),
    (axes[1, 0], "FP8, causal=false", "FP8, causal=False",
     (0, 1250), [0, 250, 500, 750, 1000], methods_down),
    (axes[1, 1], "FP8, causal=true", "FP8, causal=True",
     (0, 1250), [0, 250, 500, 750, 1000], methods_down),
)
for panel_ax, panel_title, data_key, y_range, y_marks, names in panel_specs:
    plot_panel(
        panel_ax,
        panel_title,
        data[data_key],
        ylim=y_range,
        yticks=y_marks,
        methods=names,
    )

# Single figure-level legend, built from the top-left panel's handles
# (that panel is drawn with `methods_up`).
legend_handles, legend_labels = axes[0, 0].get_legend_handles_labels()
fig.legend(
    legend_handles,
    legend_labels,
    loc="upper center",
    ncol=5,
    bbox_to_anchor=(0.5, 0.98),
)

# Figure-level axis labels: written once and centred on the whole figure
# rather than repeated on individual panels.
fig.supylabel("Throughput (TFLOPs/s)")  # vertically centred on the left
fig.supxlabel("Context length")  # horizontally centred at the bottom

plt.subplots_adjust(
    top=0.82,
    bottom=0.12,
    left=0.08,
    right=0.98,
    wspace=0.15,
    hspace=0.35,
)
plt.show()

# batch4-head32-d128
94 changes: 94 additions & 0 deletions docker/draw/draw_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np
import matplotlib.pyplot as plt

# Heights of the horizontal "Theoretical Peak" reference lines.
theoretical_fp16 = 1000
theoretical_fp8 = 2000

# Width of a single bar, in units of the x slots (0, 1, 2, ...).
bar_width = 0.14

# GEMM K sizes shown as tick labels, and the evenly spaced slots
# (0, 1, 2, ...) the bars are actually drawn at.
gemm_k = np.array([256, 512, 1024, 2048, 4096, 8192, 16384])
x = np.arange(gemm_k.size)

# ------- sample data, replace with your own -------
# One value per entry of `gemm_k`.

# FP16
cublas_fp16 = np.array([517, 626, 712, 717, 697, 680, 667])
# tawa_fp16 = np.array([580, 760, 780, 770, 780, 760, 740])
txl_fp16 = np.array([473, 615, 705, 739, 748, 701, 694])
triton_fp16 = np.array([470, 603, 680, 678, 670, 640, 630])
tile_fp16 = np.array([300, 420, 600, 690, 700, 720, 740])
tk_fp16 = np.array([400, 680, 680, 709, 780, 788, 798])

# FP8
cublas_fp8 = np.array([888, 1203, 1385, 1503, 1561, 1573, 1436])
# tawa_fp8 = np.array([900, 1470, 1600, 1600, 1550, 1500, 1400])
txl_fp8 = np.array([1015, 1308, 1437, 1543, 1565, 1509, 1432])
triton_fp8 = np.array([720, 1212, 1502, 1530, 1535, 1528, 1478])
tile_fp8 = np.array([231, 312, 547, 712, 892, 930, 1003])
tk_fp8 = np.array([579, 860, 1232, 1398, 1497, 1503, 1429])
# --------------------------------------------------

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 3), sharey=False)

# One row per method, in left-to-right order inside each group of bars:
# (legend label, colour, extra bar kwargs, FP16 values, FP8 values).
# Only the cuBLAS bars are hatched.
series = (
    ("cuBLAS", "#f1c40f", {"hatch": "//"}, cublas_fp16, cublas_fp8),
    ("Txl", "#e74c3c", {}, txl_fp16, txl_fp8),
    ("Triton", "#1abc9c", {}, triton_fp16, triton_fp8),
    ("TileLang", "#ff6fb3", {}, tile_fp16, tile_fp8),
    ("ThunderKittens", "#3498db", {}, tk_fp16, tk_fp8),
)

# ------------ left panel: FP16 ------------
ax = axes[0]
ax.axhline(
    theoretical_fp16, color="gray", linewidth=3, label="Theoretical Peak"
)

# Slots -2..2 centre the five bars of each group on its x position.
for slot, (name, shade, extra, fp16_vals, _) in enumerate(series, start=-2):
    ax.bar(
        x + slot * bar_width,
        fp16_vals,
        bar_width,
        label=name,
        color=shade,
        **extra,
    )

ax.set_title("FP16")
ax.set_ylabel("Throughput (TFLOPs/s)")
ax.set_xticks(x)
ax.set_xticklabels(gemm_k)
ax.set_xlabel("GEMM K size")
ax.grid(axis="y", linestyle="--", alpha=0.4)

left_ylim = 1200
ax.set_ylim(0, left_ylim)
ax.set_yticks([0, 200, 400, 600, 800, 1000, 1200])

# ------------ right panel: FP8 ------------
# These artists get no labels.
ax = axes[1]
ax.axhline(theoretical_fp8, color="gray", linewidth=3)

for slot, (_, shade, extra, _, fp8_vals) in enumerate(series, start=-2):
    ax.bar(x + slot * bar_width, fp8_vals, bar_width, color=shade, **extra)

ax.set_title("FP8")
ax.set_xticks(x)
ax.set_xticklabels(gemm_k)
ax.set_xlabel("GEMM K size")
ax.grid(axis="y", linestyle="--", alpha=0.4)

# Key point: scale the right y range so that the 2000 line sits at the same
# height as the 1000 line in the left panel.
right_ylim = theoretical_fp8 * left_ylim / theoretical_fp16  # = 2400
ax.set_ylim(0, right_ylim)
ax.set_yticks(np.arange(0, 2001, 500))  # 0, 500, 1000, 1500, 2000

# ------------ top legend & layout ------------
# The legend is built from the left (FP16) panel's handles.
legend_handles, legend_labels = axes[0].get_legend_handles_labels()
fig.legend(
    legend_handles,
    legend_labels,
    loc="upper center",
    ncol=6,
    bbox_to_anchor=(0.5, 0.98),
)

plt.subplots_adjust(
    top=0.78,
    bottom=0.18,
    left=0.08,
    right=0.98,
    wspace=0.15,
    hspace=0.35,
)
plt.show()

# m=8192, n=8192
121 changes: 121 additions & 0 deletions docker/draw/draw_mla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import numpy as np
import matplotlib.pyplot as plt

# Method names drawn in the s_q=1 panel and in the s_q=2 panel.
methods_q1 = ["FlashMLA", "Txl", "Triton", "TileLang", "Flashinfer"]
methods_q2 = ["FlashMLA", "Txl", "Flashinfer"]

# Bar colour for each method.
colors = {
    "FlashMLA": "#f1c40f",
    "Txl": "#e74c3c",
    "Triton": "#1abc9c",
    "TileLang": "#ff6fb3",
    "Flashinfer": "#e67e22",
}

# X axis: the context lengths shared by the panels.
ctx = np.array([1024, 2048, 4096, 8192, 16384, 32768])
# Positions actually used for drawing: 0, 1, 2, ... (one slot per context
# length), so the groups are evenly spaced.
x = np.arange(ctx.size)

# ============ sample data, replace as needed ============
# One value per entry of `ctx`.

# FP16, causal = False, sq=1
fp16_nc_fm = np.array([436, 510, 543, 564, 582, 601])
fp16_nc_txl = np.array([401, 490, 518, 535, 538, 561])
fp16_nc_triton = np.array([19, 28, 34, 38, 44, 46])
fp16_nc_tile = np.array([237, 412, 459, 473, 498, 477])
fp16_nc_fi = np.array([406, 491, 527, 532, 528, 552])

# FP16, causal = False, sq=2
fp16_nc2_fm = np.array([521, 591, 621, 628, 579, 626])
fp16_nc2_txl = np.array([475, 531, 554, 557, 565, 587])
fp16_nc2_fi = np.array([486, 534, 535, 545, 536, 539])
# ========================================================

# Panel name -> {method name -> values}; each inner mapping follows the
# order of the matching `methods_*` list.
data = {
    "FP16, causal=False, s_q=1": dict(zip(
        methods_q1,
        [fp16_nc_fm, fp16_nc_txl, fp16_nc_triton, fp16_nc_tile, fp16_nc_fi],
    )),
    "FP16, causal=False, s_q=2": dict(zip(
        methods_q2,
        [fp16_nc2_fm, fp16_nc2_txl, fp16_nc2_fi],
    )),
}

# Two side-by-side panels that share the x axis.
fig, ax = plt.subplots(nrows=1, ncols=2, sharex=True, figsize=(12, 4))

# Width of a single bar, in units of the x slots (0, 1, 2, ...).
bar_width = 0.16

def plot_panel(ax, title, panel_data, ylim, yticks, methods):
    """Draw one grouped-bar panel on ``ax``.

    Reads the module-level ``x``, ``ctx``, ``bar_width`` and ``colors``.

    Args:
        ax: Axes object to draw on.
        title: Title set on the panel.
        panel_data: Mapping from method name to the bar heights for it.
        ylim: Pair unpacked into ``ax.set_ylim``.
        yticks: Tick positions passed to ``ax.set_yticks``.
        methods: Method names to draw, in left-to-right order inside
            each group of bars.
    """
    n_methods = len(methods)
    # Per-method horizontal shifts, symmetric around zero, so each group of
    # bars is centred on its x slot.
    shifts = (np.arange(n_methods) - (n_methods - 1) / 2) * bar_width

    for shift, name in zip(shifts, methods):
        heights = panel_data[name]
        # Only the FlashMLA bars are hatched.
        hatch_kwargs = {"hatch": "//"} if name == "FlashMLA" else {}
        ax.bar(
            x + shift,  # note: slot positions `x`, not the raw `ctx` values
            heights,
            bar_width,
            label=name,
            color=colors[name],
            **hatch_kwargs,
        )

    ax.set_title(title)
    ax.set_ylim(*ylim)
    ax.set_yticks(yticks)
    ax.grid(axis="y", linestyle="--", alpha=0.4)
    # Ticks sit on the slots; only their labels show 1024, 2048, ...
    ax.set_xticks(x)
    ax.set_xticklabels(ctx)

# One entry per panel: (axes, title, key into `data`, methods).
# Left panel: s_q=1. Right panel: s_q=2. Both use the same y range and ticks.
panel_specs = (
    (ax[0], "FP16, causal=false, s_q=1", "FP16, causal=False, s_q=1",
     methods_q1),
    (ax[1], "FP16, causal=false, s_q=2", "FP16, causal=False, s_q=2",
     methods_q2),
)
for panel_ax, panel_title, data_key, names in panel_specs:
    plot_panel(
        panel_ax,
        panel_title,
        data[data_key],
        ylim=(0, 800),
        yticks=[0, 200, 400, 600, 800],
        methods=names,
    )

# Single figure-level legend, built from the left panel's handles
# (that panel is drawn with `methods_q1`).
legend_handles, legend_labels = ax[0].get_legend_handles_labels()
fig.legend(
    legend_handles,
    legend_labels,
    loc="upper center",
    ncol=5,
    bbox_to_anchor=(0.5, 0.98),
)

# The y label goes on the left panel only; both panels get an x label.
ax[0].set_ylabel("Throughput (TFLOPs/s)")
for panel_ax in ax:
    panel_ax.set_xlabel("Context length")

plt.subplots_adjust(top=0.82)

plt.show()

# b=132, s_q=2, h_q=128, h_kv=1, d=576, dv=512, causal=False, dtype=torch.float16
Loading