From aa46fd3a88d004dd72d299c48d8fb5fecc280ba1 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Thu, 4 Dec 2025 11:15:38 +0800 Subject: [PATCH 1/6] add endtoend test --- docker/end_to_end/LICENSE | 21 ++ docker/end_to_end/README.md | 78 +++++ docker/end_to_end/bench.py | 156 +++++++++ docker/end_to_end/gpt.py | 381 +++++++++++++++++++++ docker/end_to_end/kernels.py | 525 +++++++++++++++++++++++++++++ docker/end_to_end/kernels_attn.py | 437 ++++++++++++++++++++++++ docker/end_to_end/kernels_ffn.py | 278 +++++++++++++++ docker/end_to_end/modal_test.py | 123 +++++++ docker/end_to_end/requirements.txt | 6 + docker/end_to_end/test.py | 152 +++++++++ 10 files changed, 2157 insertions(+) create mode 100644 docker/end_to_end/LICENSE create mode 100644 docker/end_to_end/README.md create mode 100644 docker/end_to_end/bench.py create mode 100644 docker/end_to_end/gpt.py create mode 100644 docker/end_to_end/kernels.py create mode 100644 docker/end_to_end/kernels_attn.py create mode 100644 docker/end_to_end/kernels_ffn.py create mode 100644 docker/end_to_end/modal_test.py create mode 100644 docker/end_to_end/requirements.txt create mode 100644 docker/end_to_end/test.py diff --git a/docker/end_to_end/LICENSE b/docker/end_to_end/LICENSE new file mode 100644 index 0000000..bb9379b --- /dev/null +++ b/docker/end_to_end/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Vasudev Gupta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/docker/end_to_end/README.md b/docker/end_to_end/README.md new file mode 100644 index 0000000..9204506 --- /dev/null +++ b/docker/end_to_end/README.md @@ -0,0 +1,78 @@ +Triton implementation of GPT/LLAMA models. Objective of this project is to understand how much performance can be squeezed out if we implement full-GPT-block in one triton kernel. + +**Performance** + +triton implementation is more fast & memory efficient compared to HuggingFace Transformers implementation. + +```bash +python3 bench.py +``` + +**Latency** + +| precision | HuggingFace GPT | Triton GPT | +|------------------------|-----------------|------------| +| fp32 | 1800 ms | - | +| tf32 | 631.35 ms | 462.63 ms | +| mixed precision (fp16) | 510.80 ms | 273 ms | +| fp16 | 301.92 ms | - | + +_time taken to process batch size - 512x300 on 1 A100 40 GB_ + + **Max Batch Size** + +| | max batch size | +|------------------------|----------------| +| HuggingFace GPT | 1024 | +| Triton GPT | 2048 | + +_I considered batch sizes with power of 2 only. 
Both runs had seqlen=300 and mixed precision was enabled._ + +**MFU** + +```python +from gpt import compute_mfu +# fwd MFU + +# HuggingFace GPT (fp16) +compute_mfu(2 * 124 * 10**6 * 512*512 / 0.302, gpu="h100") +# 21.76% + +# HuggingFace GPT (mixed precision) +compute_mfu(2 * 124 * 10**6 * 512*512 / 0.510, gpu="h100") +# 12.88% + +# triton (mixed precision) +compute_mfu(2 * 124 * 10**6 * 512*512 / 0.273, gpu="h100") +# 24.07% +``` + +**Supported Features** +* [x] fused implementation of several components of GPT block (for eg: `dropout(wte(x) + wpe(x))`, `dropout(wx + b)`, `gelu(wx + b)`) +* [x] flash attention v1 algorithm +* [x] GPT2 implementation in triton +* [x] support for loading pre-trained weights of huggingface-gpt2 +* [ ] support KV cache & sampling for inference loop +* [ ] implement back-propogation of GPT block in triton (i.e. solving the math problem) +* [ ] implement paged-attention from vLLM project in triton +* [ ] implement flash attention v2 & v3 +* [ ] add kernels for LLAMA-3.1 +* [ ] implement adamw in triton (with FSDP-stage2 support) + +**Installation** + +```bash +pip3 install -r requirements.txt +# `numpy<2` is hard-requirement for running on CPU +# else triton gives garbage - likely some bug in triton +``` + +**Running tests** + +```python +# you can run following command on CPU +TRITON_INTERPRET=1 pytest -sv test.py + +# you can run following command on GPU +pytest -sv test.py +``` diff --git a/docker/end_to_end/bench.py b/docker/end_to_end/bench.py new file mode 100644 index 0000000..fafea6d --- /dev/null +++ b/docker/end_to_end/bench.py @@ -0,0 +1,156 @@ +# TRITON_INTERPRET=1 python3 bench.py + +import torch +import triton +from transformers import AutoTokenizer +from gpt import convert_hf_and_load_model, load_model_only + +STRING = """\ +Large language models have been shown to achieve remarkable performance across a variety of natural\ +language tasks using few-shot learning, which drastically reduces the number of task-specific 
training\ +examples needed to adapt the model to a particular application. To further our understanding of the\ +impact of scale on few-shot learning, we trained a 540-billion parameter, densely activated, Transformer\ +language model, which we call Pathways Language Model (PaLM).\ +We trained PaLM on 6144 TPU v4 chips using Pathways, a new ML system which enables highly efficient\ +training across multiple TPU Pods. We demonstrate continued benefits of scaling by achieving state-ofthe-art few-shot learning results on hundreds of language understanding and generation benchmarks. On a\ +number of these tasks, PaLM 540B achieves breakthrough performance, outperforming the finetuned stateof-the-art on a suite of multi-step reasoning tasks, and outperforming average human performance on the\ +recently released BIG-bench benchmark. A significant number of BIG-bench tasks showed discontinuous\ +improvements from model scale, meaning that performance steeply increased as we scaled to our largest\ +model. PaLM also has strong capabilities in multilingual tasks and source code generation, which we\ +demonstrate on a wide array of benchmarks. We additionally provide a comprehensive analysis on bias\ +and toxicity, and study the extent of training data memorization with respect to model scale. 
Finally,\ +we discuss the ethical considerations related to large language models and discuss potential mitigation\ +strategies.\ +""" + +def run_benchmark(provider, warmup=25, rep=100, mixed_precison=True, max_length=1024, need_gpt=True): + assert torch.cuda.is_available() + device = "cuda" + model_id = "gpt2" + if need_gpt: + model, hf_model, txl_model = convert_hf_and_load_model(model_id, device) + else: + model, txl_model = load_model_only(device) + if mixed_precison: + model.to(torch.float16) + txl_model.to(torch.float16) + # hf_model.to(torch.float16) + tokenizer = AutoTokenizer.from_pretrained(model_id) + # triton is slow for batch_size = 1 with current settings but much faster with batch > 1 + inputs = tokenizer([STRING*1000] * 32, return_tensors="pt", max_length=max_length, truncation=True) + inputs = {k: v.to(device) for k, v in inputs.items()} + # print(inputs) + print("input_ids shape:", inputs["input_ids"].shape) + # exit() + with torch.no_grad(): + # z_torch = hf_model(**inputs).last_hidden_state + # z = model(inputs["input_ids"]) + # z_txl = txl_model(inputs["input_ids"]) + # print(f"txl output sample: {z_txl[0, :5, :5]}") + # print(f"hf_model output sample: {z_torch[0, :5, :5]}") + # print(f"model output sample: {z[0, :5, :5]}") + # max_diff = torch.abs(z_torch - z).max().item() + # max_diff_txl = torch.abs(z_txl - z).max().item() + # max_diff_txl_hf = torch.abs(z_txl - z_torch).max().item() + # print(f"max diff between hf_model and model: {max_diff}") + # print(f"max diff between txl_model and model: {max_diff_txl}") + # print(f"max diff between txl_model and hf_model: {max_diff_txl_hf}") + # max_diff_location = torch.abs(z_torch - z).argmax().item() + # print(f"model output around max diff location: {z.view(-1)[max_diff_location-5:max_diff_location+5]}") + # print(f"hf_model output around max diff location: {z_torch.view(-1)[max_diff_location-5:max_diff_location+5]}") + # print(f"txl_model output around max diff location: 
{z_txl.view(-1)[max_diff_location-5:max_diff_location+5]}") + # max_diff_location_txl = torch.abs(z_txl - z_torch).argmax().item() + # print(f"model output around max diff location (txl): {z.view(-1)[max_diff_location_txl-5:max_diff_location_txl+5]}") + # print(f"txl_model output around max diff location: {z_txl.view(-1)[max_diff_location_txl-5:max_diff_location_txl+5]}") + # print(f"hf_model output around max diff location (txl): {z_torch.view(-1)[max_diff_location_txl-5:max_diff_location_txl+5]}") + if provider == "torch": + def fn(): + if mixed_precison: + with torch.autocast(device_type="cuda", dtype=torch.float16): + return hf_model(**inputs).last_hidden_state + else: + return hf_model(**inputs).last_hidden_state + return triton.testing.do_bench(fn, warmup=warmup, rep=rep) + if provider == "triton": + fn = lambda: model(inputs["input_ids"]) + return triton.testing.do_bench(fn, warmup=warmup, rep=rep) + if provider == "txl": + fn = lambda: txl_model(inputs["input_ids"]) + return triton.testing.do_bench(fn, warmup=warmup, rep=rep) + +def validate(mixed_precison=True, max_length=1024): + assert torch.cuda.is_available() + device = "cuda" + model_id = "gpt2" + model, hf_model, txl_model = convert_hf_and_load_model(model_id, device) + if mixed_precison: + model.to(torch.float16) + txl_model.to(torch.float16) + # hf_model.to(torch.float16) + tokenizer = AutoTokenizer.from_pretrained(model_id) + # triton is slow for batch_size = 1 with current settings but much faster with batch > 1 + inputs = tokenizer([STRING*1000] * 32, return_tensors="pt", max_length=max_length, truncation=True) + inputs = {k: v.to(device) for k, v in inputs.items()} + print(inputs) + print("input_ids shape:", inputs["input_ids"].shape) + # exit() + with torch.no_grad(): + z_torch = hf_model(**inputs).last_hidden_state + z = model(inputs["input_ids"]) + z_txl = txl_model(inputs["input_ids"]) + print(f"txl output sample: {z_txl[0, :5, :5]}") + print(f"hf_model output sample: {z_torch[0, :5, :5]}") 
+ print(f"triton_model output sample: {z[0, :5, :5]}") + mean_diff_triton = torch.abs(z_torch - z).mean().item() + mean_diff_txl = torch.abs(z_torch - z_txl).mean().item() + print(f"mean diff between triton_model and hf_model: {mean_diff_triton}") + print(f"mean diff between txl_model and hf_model: {mean_diff_txl}") + +# 1 A100 40 GB +# torch: batch_size = 512 && t = 1801.32 +# triton: batch_size = 512 && t = 789.14 +# torch: batch_size = 1024 && OOM +# triton: batch_size = 2048 && t = 3153.70 + +print("start") +validate(mixed_precison=True, max_length=1024) +print("triton:", run_benchmark("triton")) +print("txl:", run_benchmark("txl")) +print("torch:", run_benchmark("torch")) +print("triton long:", run_benchmark("triton", max_length=4096, need_gpt=False)) +print("txl long:", run_benchmark("txl", max_length=4096, need_gpt=False)) +# print("txl test:", run_benchmark("txl", max_length=1024, need_gpt=False)) +# print("triton test:", run_benchmark("triton", max_length=1024, need_gpt=False)) +# OLD SUMMARY +# fp32 +# torch: 1800 +# triton: 789.14 + +# mixed precision +# torch: 510.80 +# triton: 429.80 + +# fp16 +# torch: 301.92 + +# triton with mixed precison = False +# ffn cast enabled: 791.13 +# flash cast enabled: 759.71 +# num_warps = 8 & BLOCK_SIZE = 64 ffn :: 759.18 +# num_warps = 8 & BLOCK_SIZE = 128 ffn :: 463.80 +# layer norm BLOCK_SIZE = 32768 :: 832.63 +# layer norm BLOCK_SIZE = 512 :: 462.61 +# embeddings BLOCK_SIZE = 512 :: 462.87 +# attention BLOCK_SIZE = 128 & num_stages = 4 :: 1279.38 +# attention BLOCK_SIZE = 128 & num_stages = 8 :: 460.27 +# final config: embeddings (512, 4) + layer norm (512, 4) + ffn (128, 128, 64, 8) + attention (128, 8) + +# mixed precision = True +# triton: 273.61 +# with attention (128, 8), t = 900 but with attention (64, 4), t = 273! 
+ +# mixed precision = False +# torch.backends.cuda.matmul.allow_tf32 = True +# torch.backends.cudnn.allow_tf32 = True +# torch: 623.3262329101562 + diff --git a/docker/end_to_end/gpt.py b/docker/end_to_end/gpt.py new file mode 100644 index 0000000..7784ebb --- /dev/null +++ b/docker/end_to_end/gpt.py @@ -0,0 +1,381 @@ +# TRITON_INTERPRET=1 python3 gpt.py + +from dataclasses import dataclass + +import torch +import torch.nn as nn +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers import GPT2Model as HFGPT2 + +from kernels import (flash_attention_v1, fused_embeddings, fused_ffn, + fused_layer_norm, matmul_and_split_qkv) +from kernels_attn import flash_attention_v2 +from kernels_ffn import fused_ffn_txl + +GPU_TO_FLOPS = { + "v100": 130 * 10**12, + "a100": 312 * 10**12, + "h100": 989 * 10**12, +} + + +class FusedAttention_v2(nn.Module): + def __init__(self, hidden_size, num_heads, dropout_prob=0.0): + super().__init__() + self.dropout_prob = dropout_prob + self.num_heads = num_heads + + self.hidden_size = hidden_size + + self.layer_norm_weight = nn.Parameter(torch.ones(hidden_size)) + self.layer_norm_bias = nn.Parameter(torch.zeros(hidden_size)) + + self.c_attn_weight = nn.Parameter(torch.rand(hidden_size, 3 * hidden_size)) + self.c_attn_bias = nn.Parameter(torch.rand(3 * hidden_size)) + + self.c_proj_weight = nn.Parameter(torch.rand(hidden_size, hidden_size)) + self.c_proj_bias = nn.Parameter(torch.rand(hidden_size)) + + def forward(self, x): + residual = x + x = fused_layer_norm(x, self.layer_norm_weight.data, self.layer_norm_bias.data) + q, k, v = matmul_and_split_qkv( + x, self.c_attn_weight.data, self.c_attn_bias.data, self.num_heads + ) + dropout_prob = self.dropout_prob if self.training else 0.0 + x = flash_attention_v2( + q, + k, + v, + dropout_prob=dropout_prob, + ) + x = x.transpose(1, 2).contiguous().view(residual.shape) + x = fused_ffn_txl( + x, + self.c_proj_weight.data, + bias=self.c_proj_bias.data, + 
residual=residual, + add_gelu=False, + dropout_prob=dropout_prob, + ) + return x + + def get_fwd_flops(self, num_tokens): + h = self.hidden_size + layer_norm = num_tokens * h + num_tokens * h + c_attn = num_tokens * (3 * h) * (2 * h) + num_tokens * (3 * h) + c_proj = num_tokens * h * (2 * h) + num_tokens * h + return layer_norm + c_attn + c_proj + +class FusedAttention_v1(nn.Module): + def __init__(self, hidden_size, num_heads, dropout_prob=0.0): + super().__init__() + self.dropout_prob = dropout_prob + self.num_heads = num_heads + + self.hidden_size = hidden_size + + self.layer_norm_weight = nn.Parameter(torch.ones(hidden_size)) + self.layer_norm_bias = nn.Parameter(torch.zeros(hidden_size)) + + self.c_attn_weight = nn.Parameter(torch.rand(hidden_size, 3 * hidden_size)) + self.c_attn_bias = nn.Parameter(torch.rand(3 * hidden_size)) + + self.c_proj_weight = nn.Parameter(torch.rand(hidden_size, hidden_size)) + self.c_proj_bias = nn.Parameter(torch.rand(hidden_size)) + + def forward(self, x): + residual = x + x = fused_layer_norm(x, self.layer_norm_weight.data, self.layer_norm_bias.data) + q, k, v = matmul_and_split_qkv( + x, self.c_attn_weight.data, self.c_attn_bias.data, self.num_heads + ) + dropout_prob = self.dropout_prob if self.training else 0.0 + x = flash_attention_v1( + q, + k, + v, + dropout_prob=dropout_prob, + ) + x = x.transpose(1, 2).contiguous().view(residual.shape) + x = fused_ffn( + x, + self.c_proj_weight.data, + bias=self.c_proj_bias.data, + residual=residual, + add_gelu=False, + dropout_prob=dropout_prob, + ) + return x + + def get_fwd_flops(self, num_tokens): + h = self.hidden_size + layer_norm = num_tokens * h + num_tokens * h + c_attn = num_tokens * (3 * h) * (2 * h) + num_tokens * (3 * h) + c_proj = num_tokens * h * (2 * h) + num_tokens * h + return layer_norm + c_attn + c_proj + +class FusedMLP_txl(nn.Module): + def __init__(self, hidden_size, dropout_prob=0.0): + super().__init__() + + self.dropout_prob = dropout_prob + + 
self.layer_norm_weight = nn.Parameter(torch.ones((hidden_size,))) + self.layer_norm_bias = nn.Parameter(torch.zeros((hidden_size,))) + + intermediate_size = 4 * hidden_size + + self.ffn1_weight = nn.Parameter(torch.rand(hidden_size, intermediate_size)) + self.ffn1_bias = nn.Parameter(torch.rand(intermediate_size)) + + self.ffn2_weight = nn.Parameter(torch.rand(intermediate_size, hidden_size)) + self.ffn2_bias = nn.Parameter(torch.rand(hidden_size)) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + def forward(self, x): + # mlp = DROPOUT(GELU(LN(X) @ A + a) @ B + b) + X + dropout_prob = self.dropout_prob if self.training else 0.0 + residual = x + x = fused_layer_norm(x, self.layer_norm_weight.data, self.layer_norm_bias.data) + x = fused_ffn_txl( + x, + self.ffn1_weight.data, + bias=self.ffn1_bias.data, + residual=None, + add_gelu=True, + dropout_prob=dropout_prob, + ) + x = fused_ffn_txl( + x, + self.ffn2_weight.data, + bias=self.ffn2_bias.data, + residual=residual, + add_gelu=False, + dropout_prob=dropout_prob, + ) + return x + + def get_fwd_flops(self, num_tokens): + h = self.hidden_size + mid = self.intermediate_size + layer_norm = num_tokens * h + num_tokens * h + ffn1 = num_tokens * mid * (2 * h) + num_tokens * mid + ffn2 = num_tokens * h * (2 * mid) + num_tokens * h + return layer_norm + ffn1 + ffn2 + +class FusedMLP_triton(nn.Module): + def __init__(self, hidden_size, dropout_prob=0.0): + super().__init__() + + self.dropout_prob = dropout_prob + + self.layer_norm_weight = nn.Parameter(torch.ones((hidden_size,))) + self.layer_norm_bias = nn.Parameter(torch.zeros((hidden_size,))) + + intermediate_size = 4 * hidden_size + + self.ffn1_weight = nn.Parameter(torch.rand(hidden_size, intermediate_size)) + self.ffn1_bias = nn.Parameter(torch.rand(intermediate_size)) + + self.ffn2_weight = nn.Parameter(torch.rand(intermediate_size, hidden_size)) + self.ffn2_bias = nn.Parameter(torch.rand(hidden_size)) + + self.hidden_size = hidden_size 
+ self.intermediate_size = intermediate_size + + def forward(self, x): + # mlp = DROPOUT(GELU(LN(X) @ A + a) @ B + b) + X + dropout_prob = self.dropout_prob if self.training else 0.0 + residual = x + x = fused_layer_norm(x, self.layer_norm_weight.data, self.layer_norm_bias.data) + x = fused_ffn( + x, + self.ffn1_weight.data, + bias=self.ffn1_bias.data, + residual=None, + add_gelu=True, + dropout_prob=dropout_prob, + ) + x = fused_ffn( + x, + self.ffn2_weight.data, + bias=self.ffn2_bias.data, + residual=residual, + add_gelu=False, + dropout_prob=dropout_prob, + ) + return x + + def get_fwd_flops(self, num_tokens): + h = self.hidden_size + mid = self.intermediate_size + layer_norm = num_tokens * h + num_tokens * h + ffn1 = num_tokens * mid * (2 * h) + num_tokens * mid + ffn2 = num_tokens * h * (2 * mid) + num_tokens * h + return layer_norm + ffn1 + ffn2 + +@dataclass +class GPTConfig: + vocab_size: int = 50304 + block_size: int = 512 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.1 + + +class FusedGPT(nn.Module): + def __init__(self, config, txl=False): + super().__init__() + self.config = config + + self.wte_weight = nn.Parameter(torch.rand(config.vocab_size, config.n_embd)) + self.wpe_weight = nn.Parameter(torch.rand(config.block_size, config.n_embd)) + FusedAttention = FusedAttention_v2 if txl else FusedAttention_v1 + FusedMLP = FusedMLP_txl if txl else FusedMLP_triton + + self.blocks = nn.ModuleList( + [ + nn.Sequential( + FusedAttention( + config.n_embd, + config.n_head, + dropout_prob=config.dropout, + ), + FusedMLP( + config.n_embd, + dropout_prob=config.dropout, + ), + ) + for _ in range(config.n_layer) + ] + ) + self.layer_norm_weight = nn.Parameter(torch.ones((config.n_embd,))) + self.layer_norm_bias = nn.Parameter(torch.zeros((config.n_embd,))) + + # TODO: we don't wanna consume consume 2x memory here because of transpose and contiguous + # instead implement transposed matmul in triton kernel + # self.lm_head_weight = 
self.wte.weight.data.T.contiguous() + + def forward(self, x): + # it does causal automatically, no need of separate attention/padding mask + dropout_prob = self.config.dropout_prob if self.training else 0.0 + x = fused_embeddings( + x, self.wte_weight.data, self.wpe_weight.data, dropout_prob=dropout_prob + ) + for block in self.blocks: + x = block(x) + x = fused_layer_norm(x, self.layer_norm_weight, self.layer_norm_bias) + # x = fused_ffn( + # x, + # self.lm_head_weight, + # bias=None, + # residual=None, + # add_gelu=False, + # dropout_prob=0.0, + # ) + return x + + def get_fwd_flops(self, num_tokens): + h = self.config.n_embd + v = self.config.vocab_size + p = self.config.block_size + wte = num_tokens * h * (2 * v) + wpe = num_tokens * h * (2 * p) + blocks = sum( + [ + module.get_fwd_flops(num_tokens) + for block in self.blocks + for module in block + ] + ) + layer_norm = num_tokens * h + num_tokens * h + return blocks + layer_norm + wte + wpe + + +def convert_huggingface_to_triton(hf_sd, hf_config): + config = GPTConfig( + vocab_size=hf_config.vocab_size, + block_size=hf_config.n_ctx, + n_layer=hf_config.n_layer, + n_head=hf_config.n_head, + n_embd=hf_config.n_embd, + dropout=0.1, + ) + mapping = { + "wte.weight": "wte_weight", + "wpe.weight": "wpe_weight", + "ln_f.weight": "layer_norm_weight", + "ln_f.bias": "layer_norm_bias", + } + block = { + "h.{i}.ln_1.weight": "blocks.{i}.0.layer_norm_weight", + "h.{i}.ln_1.bias": "blocks.{i}.0.layer_norm_bias", + "h.{i}.attn.bias": None, + "h.{i}.attn.c_attn.weight": "blocks.{i}.0.c_attn_weight", + "h.{i}.attn.c_attn.bias": "blocks.{i}.0.c_attn_bias", + "h.{i}.attn.c_proj.weight": "blocks.{i}.0.c_proj_weight", + "h.{i}.attn.c_proj.bias": "blocks.{i}.0.c_proj_bias", + "h.{i}.ln_2.weight": "blocks.{i}.1.layer_norm_weight", + "h.{i}.ln_2.bias": "blocks.{i}.1.layer_norm_bias", + "h.{i}.mlp.c_fc.weight": "blocks.{i}.1.ffn1_weight", + "h.{i}.mlp.c_fc.bias": "blocks.{i}.1.ffn1_bias", + "h.{i}.mlp.c_proj.weight": 
"blocks.{i}.1.ffn2_weight", + "h.{i}.mlp.c_proj.bias": "blocks.{i}.1.ffn2_bias", + } + for k, v in block.items(): + if v is None: + continue + for i in range(config.n_layer): + mapping[k.format(i=i)] = v.format(i=i) + sd = {} + for k, v in tqdm(hf_sd.items()): + sd[mapping[k]] = v + return sd, config + + +def convert_hf_and_load_model(model_id, device): + hf_model = HFGPT2.from_pretrained(model_id) + state_dict, config = convert_huggingface_to_triton( + hf_model.state_dict(), hf_model.config + ) + # print(hf_model.config) + model = FusedGPT(config) + txl_model = FusedGPT(config, txl=True) + model.load_state_dict(state_dict) + txl_model.load_state_dict(state_dict) + print("Model loaded.") + return model.to(device).eval(), hf_model.to(device).eval(), txl_model.to(device).eval() + +def load_model_only(device): + config = GPTConfig( + vocab_size=50304, + block_size=16384, + n_layer=24, + n_head=12, + n_embd=768, + dropout=0.0, + ) + txl_model = FusedGPT(config, txl=True) + model = FusedGPT(config) + print("Model created.") + return model.to(device).eval(), txl_model.to(device).eval() + +def estimate_days(flops, mfu=0.45, gpu="h100", num_gpus=1): + # its probably very hard to achieve 0.45 mfu - LOL + # but thats kinda SOTA in papers from top labs + assert gpu in GPU_TO_FLOPS + return flops / (mfu * GPU_TO_FLOPS[gpu] * 3600 * 24 * num_gpus) + + +def get_num_parameters(model): + return sum([p.numel() for p in model.parameters()]) + + +def compute_mfu(flops_per_second, gpu="h100"): + assert gpu in GPU_TO_FLOPS + return flops_per_second / GPU_TO_FLOPS[gpu] diff --git a/docker/end_to_end/kernels.py b/docker/end_to_end/kernels.py new file mode 100644 index 0000000..2a09e33 --- /dev/null +++ b/docker/end_to_end/kernels.py @@ -0,0 +1,525 @@ +import math + +import torch +import triton +import triton.language as tl + +# torch becomes 3x faster with following lines for fp32 +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +# TODO: shift to 
`make_block_ptr`? + + +# tl.math.tanh doesn't exist in CPU version of triton +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def gelu_new(x): + pi = math.pi + a = tl.math.sqrt(2.0 / pi) + b = x + 0.044715 * x * x * x + return 0.5 * x * (1.0 + tanh(a * b)) + + +# TODO: fixed seed would hurt the performance +# but how do we modify seed design wise? +@triton.jit +def dropout(x, p, seed, offset): + random = tl.rand(seed, offset) + return tl.where(random > p, x / (1 - p), 0.0) + + +@triton.jit +def fused_embeddings_kernel( + x_ptr, + wte_ptr, + wpe_ptr, + z_ptr, + B, + L, + V, + P, + H, + dropout_prob=0.0, + seed=1337, + BLOCK_SIZE: tl.constexpr = 512, +): + # f = dropout(wte(x) + wpe(x)) + + # x: (B*S,) + # wte: (V, H) + # wpe: (P, H) + # z: (B*S, H) + + pid = tl.program_id(0) + wte_ptr += tl.load(x_ptr + pid) * H + wpe_ptr += (pid % L) * H + z_ptr += pid * H + + for k in range(0, H, BLOCK_SIZE): + offset = k + tl.arange(0, BLOCK_SIZE) + mask = offset < H + + z = tl.load(wte_ptr + offset, mask=mask, other=0.0) + z += tl.load(wpe_ptr + offset, mask=mask, other=0.0) + z = dropout(z, dropout_prob, seed, offset) + + tl.store(z_ptr + offset, z, mask=mask) + + +@torch.no_grad() +def fused_embeddings(x, wte, wpe, dropout_prob=0.0): + # x: (batch_size, seqlen) + # wte: (vocab_size, hidden_size) + # wpe: (block_size, hidden_size) + assert wte.shape[1] == wpe.shape[1] + assert x.is_contiguous() + assert wte.is_contiguous() + assert wpe.is_contiguous() + B, L = x.shape + V, H = wte.shape + P = wpe.shape[0] + z = torch.empty((B * L, H), device=x.device, dtype=wte.dtype) + grid = (z.shape[0],) + fused_embeddings_kernel[grid]( + x.view(-1), + wte, + wpe, + z, + B, + L, + V, + P, + H, + dropout_prob=dropout_prob, + ) + return z.view((B, L, H)) + + +@triton.jit +def fused_layer_norm_kernel( + x_ptr, w_ptr, b_ptr, z_ptr, H, eps=1e-5, BLOCK_SIZE: tl.constexpr = 512 +): + # f = ((x - mean) / (std + eps)) * w + b + # x: (M, H) + # launch with 1D grid along 
M direction + + row_id = tl.program_id(0) + x_ptr += row_id * H + z_ptr += row_id * H + + x_mean = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + for i in range(0, H, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offset, mask=(offset < H), other=0.0) + x_mean += x.to(tl.float32) + x_mean = tl.sum(x_mean) / H + + x_var = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + for i in range(0, H, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offset, mask=(offset < H), other=x_mean) + x = x.to(tl.float32) + x_var += (x - x_mean) * (x - x_mean) + x_var = tl.sum(x_var) / H + rstd = 1 / tl.sqrt(x_var + eps) + + # TODO: we could prevent this extra loop if we fuse it in ffn block? + # but thats quite hacky - so, lets move with extra loop for now + for i in range(0, H, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < H + + x = tl.load(x_ptr + offset, mask=mask, other=0.0) + w = tl.load(w_ptr + offset, mask=mask, other=0.0) + b = tl.load(b_ptr + offset, mask=mask, other=0.0) + + z = (x - x_mean) * rstd + z = z * w + b + + tl.store(z_ptr + offset, z, mask=mask) + + +@torch.no_grad() +def fused_layer_norm(x, weight, bias): + # x: (*, hidden_size) + # weight: (hidden_size,) + # bias: (hidden_size,) + assert x.is_contiguous() + assert weight.is_contiguous() + assert bias.is_contiguous() + assert weight.shape == bias.shape + assert x.shape[-1] == weight.shape[0] + out_shape = x.shape + x = x.view((-1, x.shape[-1])) + B, H = x.shape + x = x.view((B, H)) + z = torch.empty(x.shape, device=x.device, dtype=x.dtype) + fused_layer_norm_kernel[(B,)](x, weight, bias, z, H) + return z.view(out_shape) + + +# TODO: implement grouping for extra 10% speedup +# also, need to understand what's gemm matmul +@triton.jit +def fused_ffn_kernel( + x_ptr, + w_ptr, + z_ptr, + M, + N, + K, + b_ptr=None, + r_ptr=None, + apply_gelu=False, + dropout_prob=0.0, + seed=1337, + BLOCK_SIZE_M: tl.constexpr = 128, + BLOCK_SIZE_N: tl.constexpr = 
128, + BLOCK_SIZE_K: tl.constexpr = 64, +): + # f = dropout(gelu(x @ w + b)) + residual + # launch with 2D grid of blocks along M & N directions + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + # intuition is this: In normal math, we basically take 1 row of X & 1 column of W + # and just multiply element wise and add stuff + # but here we add multiple consecutive rows of X & multiple consecutive rows of W + # and do dot product basically + + # pid_m: vertical + # pid_n: horizontal + + # we basically move over output matrix and computes each block in each kernel + + # x: (M, K) + # w: (K, N) + # b: (N,) + # z: (M, N) + + # x block size: (BLOCK_SIZE_M, BLOCK_SIZE_K) + # w block size: (BLOCK_SIZE_K, BLOCK_SIZE_N) + # z block size: (BLOCK_SIZE_M, BLOCK_SIZE_N) + + # these are the pointer of 1st element for each block in output matrix + + # we basically add row-block-shift here + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None] + + # we basically add column-block-shift here + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)[None, :] + + # each block in z would be of shape-(M, N) + # block of size: BLOCK_SIZE_M x BLOCK_SIZE_K would move in horizontal direction + # block of size: BLOCK_SIZE_K x BLOCK_SIZE_N would move in vertical direction + + # we need this loop because we might not be able to fit full row of X & full column of W in-memory + z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + x_k = tl.arange(0, BLOCK_SIZE_K)[None, :] + k + x = tl.load(x_ptr + offs_m * K + x_k, mask=(offs_m < M) & (x_k < K), other=0.0) + # TODO: need to read why casting to fp16 is important here + x = x.to(tl.float16) + # (BLOCK_SIZE_M, BLOCK_SIZE_K) + + w_k = tl.arange(0, BLOCK_SIZE_K)[:, None] + k + w = tl.load(w_ptr + w_k * N + offs_n, mask=(w_k < K) & (offs_n < N), other=0.0) + w = w.to(tl.float16) + # (BLOCK_SIZE_K, BLOCK_SIZE_N) + + z = tl.dot(x, w, acc=z) + # (BLOCK_SIZE_M, BLOCK_SIZE_N) + + if 
b_ptr is not None: + b = tl.load(b_ptr + offs_n, mask=(offs_n < N), other=0.0) + z += b.to(tl.float32) + # (1, BLOCK_SIZE_N) + + z_offset = offs_m * N + offs_n + z_mask = (offs_m < M) & (offs_n < N) + + if apply_gelu: + z = gelu_new(z) + if dropout_prob > 0.0: + z = dropout(z, dropout_prob, seed, z_offset) + + if r_ptr is not None: + r = tl.load(r_ptr + z_offset, mask=z_mask) + z += r.to(tl.float32) + + tl.store(z_ptr + z_offset, z, mask=z_mask) + + +@torch.no_grad() +def fused_ffn( + x, + weight, + bias=None, + residual=None, + add_gelu=False, + dropout_prob=0.0, +): + # x: (*, K) + # weight: (K, N) + # bias: (N,) + # f = dropout(gelu(x @ w + b)) + residual + + out_shape_0 = x.shape[:-1] + x = x.view((-1, x.shape[-1])) + + M, K = x.shape + N = weight.shape[1] + + x = x.view((M, K)) + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + + assert x.is_contiguous() + assert weight.is_contiguous() + assert x.shape[1] == weight.shape[0] + if bias is not None: + assert bias.is_contiguous() + assert weight.shape[1] == bias.shape[0] + if residual is not None: + residual = residual.view(z.shape) + assert residual.is_contiguous() + + # (128, 128, 64) leads to 6x slowdown with num_stages == 4 + # while its 40% faster with num_stages = 8 + BLOCK_SIZE_M = 128 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = 64 + grid = (triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N), 1) + fused_ffn_kernel[grid]( + x, + weight, + z, + M, + N, + K, + apply_gelu=add_gelu, + dropout_prob=dropout_prob, + b_ptr=bias, + r_ptr=residual, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + num_warps=8, + ) + return z.view((*out_shape_0, N)) + + +# @triton.jit +# def softmax_kernel(x_ptr, z_ptr, L, N1, H, BLOCK_SIZE_L: tl.constexpr, B1: tl.constexpr): +# # x: (L, H) +# # out: (L, H) +# pid_0 = tl.program_id(0) +# x_ptr += pid_0 * H +# z_ptr += pid_0 * H +# max_value, denominator = 0., 0. 
# (continuation of the commented-out blockwise-softmax sketch)
# for i in range(0, H, B1):
#     offset = tl.arange(i, i + B1)
#     x = tl.load(x_ptr + offset, mask=offset < H, other=0)
#     block_max_value = tl.max(x, keep_dims=True)
#     new_max_value = tl.where(
#         block_max_value > max_value, block_max_value, max_value
#     )
#     x = tl.exp(x - new_max_value)
#     denominator = denominator / tl.exp(new_max_value - max_value)
#     denominator += tl.sum(x)
#     max_value = new_max_value
# for i in range(0, H, B1):
#     offset = tl.arange(i, i + B1)
#     x = tl.load(x_ptr + offset, mask=offset < H, other=0)
#     z = tl.exp(x - max_value)
#     z = z / denominator
#     tl.store(z_ptr + offset, z, mask=offset < H)


# TODO: what if we just write separate kernel for this?
# TODO: can we fuse this in attention kernel?
@torch.no_grad()
def matmul_and_split_qkv(x, weight, bias, num_heads):
    """Project hidden states to packed QKV with one fused matmul, then split.

    x: (batch_size, seqlen, hidden_size); weight projects to 3*hidden_size.
    Returns q, k, v each of shape (batch_size, num_heads, seqlen, head_size).
    """
    # single fused projection producing concatenated [q | k | v]
    qkv = fused_ffn(x, weight, bias=bias)
    # (batch_size, seqlen, 3 * hidden_size)
    batch_size, seqlen, packed = qkv.shape
    assert packed % 3 == 0, packed
    hidden_size = packed // 3
    assert hidden_size % num_heads == 0, (hidden_size, num_heads)
    head_size = hidden_size // num_heads

    def _to_heads(t):
        # (batch, seq, hidden) -> (batch, heads, seq, head_size)
        # TODO: the transpose+contiguous is an extra read & write - memory bound
        return (
            t.view(batch_size, seqlen, num_heads, head_size)
            .transpose(1, 2)
            .contiguous()
        )

    q, k, v = qkv.split(hidden_size, dim=2)
    return _to_heads(q), _to_heads(k), _to_heads(v)


# TODO: does triton re-compile when different tl.constexpr is passed?
# TODO: read about flash-2 and see if we can switch to that
# TODO: then read about flash-3 and see if we can switch to that instead
# TODO: can we do score computation for only unmasked positions?
# pytorch flex-attention does something like that - it would make computation 50% efficient
@triton.jit
def flash_attention_v1_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    z_ptr,
    BN,
    Lq,
    Lk,
    scale,
    H: tl.constexpr,
    dropout_prob=0.0,
    seed=1337,
    BLOCK_SIZE_L: tl.constexpr = 64,
):
    # Flash-attention v1 forward, one program per (batch*head, query-block):
    # f = (q @ k.T) / math.sqrt(head_size)
    # f = dropout(F.softmax(apply_causal_mask(f), dim=-1))
    # f = f @ v

    # q, z: (B * N, Lq, H)
    # k, v: (B * N, Lk, H)

    # advance the base pointers to this (batch*head) slice
    q_ptr += tl.program_id(0) * (Lq * H)
    z_ptr += tl.program_id(0) * (Lq * H)
    k_ptr += tl.program_id(0) * (Lk * H)
    v_ptr += tl.program_id(0) * (Lk * H)

    # assuming that `H` can stay SRAM fully and doesn't require blocking
    # this assumption was made for original implementation of flash attention as well
    # it's reasonable as most of LLMs use head size <= 256
    offs_lq = tl.program_id(1) * BLOCK_SIZE_L + tl.arange(0, BLOCK_SIZE_L)
    offs_h = tl.arange(0, H)

    q_mask = offs_lq[:, None] < Lq
    q_offs = offs_lq[:, None] * H + offs_h[None, :]
    # this remains in sram throughout computation
    q = tl.load(q_ptr + q_offs, mask=q_mask, other=0.0)
    # (BLOCK_SIZE_L, H)

    q = q.to(tl.float16)

    # loop over k, v blocks, maintaining the running (online) softmax state:
    # z = unnormalized weighted sum, max_value = running row max,
    # denominator = running sum of exp(score - max)
    z = tl.zeros((BLOCK_SIZE_L, H), dtype=tl.float32)
    max_value = tl.zeros((BLOCK_SIZE_L, 1), dtype=tl.float32) + float("-inf")
    denominator = tl.zeros((BLOCK_SIZE_L, 1), dtype=tl.float32)
    for i in range(0, Lk, BLOCK_SIZE_L):
        offs_lk = i + tl.arange(0, BLOCK_SIZE_L)
        kv_mask = offs_lk[:, None] < Lk
        kv_offs = offs_lk[:, None] * H + offs_h[None, :]

        k = tl.load(k_ptr + kv_offs, mask=kv_mask, other=0.0)
        # (BLOCK_SIZE_L, H)

        k = k.to(q.dtype)
        qk = tl.dot(q, k.trans(1, 0)) * scale
        # (BLOCK_SIZE_L, BLOCK_SIZE_L)

        # TODO: remove eventually, its for debugging
        # qk_offs = offs_lq[:, None] * Lk + offs_lk[None, :]
        # tl.store(z_ptr + qk_offs, qk)

        # apply causal mask ; we still compute the attention over the future blocks
        # we wanna optimise that eventually
        qk = tl.where(offs_lq[:, None] >= offs_lk[None, :], qk, float("-inf"))

        block_max_value = tl.max(qk, axis=1, keep_dims=True)
        # (BLOCK_SIZE_L, 1)
        new_max_value = tl.where(
            block_max_value > max_value, block_max_value, max_value
        )
        # (BLOCK_SIZE_L, 1)

        qk = tl.exp(qk - new_max_value)
        # (BLOCK_SIZE_L, BLOCK_SIZE_L)

        # rescale previously accumulated state to the new running max
        multiplier = tl.exp(max_value - new_max_value)
        denominator *= multiplier
        z *= multiplier

        denominator += tl.sum(qk, axis=1, keep_dims=True)
        max_value = new_max_value
        # (BLOCK_SIZE_L, 1)

        # NOTE: dropout is applied to the attention weights only; the
        # denominator above was accumulated before dropout
        if dropout_prob > 0.0:
            qk_offs = offs_lq[:, None] * Lk + offs_lk[None, :]
            qk = dropout(qk, dropout_prob, seed, qk_offs)

        v = tl.load(v_ptr + kv_offs, mask=kv_mask, other=0.0)
        # (BLOCK_SIZE_L, H)

        v = v.to(q.dtype)
        qk = qk.to(q.dtype)

        z = tl.dot(qk, v, acc=z)
        # (BLOCK_SIZE_L, H)

    # final softmax normalization
    z /= denominator
    z = z.to(z_ptr.dtype.element_ty)

    tl.store(z_ptr + q_offs, z, mask=q_mask)


@torch.no_grad()
def flash_attention_v1(q, k, v, dropout_prob=0.0):
    """Causal flash-attention forward.

    q, k, v: (batch_size, num_heads, seqlen, head_size).
    Returns the attention output with the same shape as ``q``.
    """
    # (batch_size, num_heads, seqlen, head_size)
    assert q.shape[:2] == k.shape[:2]
    assert q.shape[-1] == k.shape[-1]
    assert k.shape == v.shape
    # B: batch_size
    # N: num_heads
    # L: seqlen
    # H: head_size
    B, N, Lq, H = q.shape
    Lk = k.shape[2]

    assert H in {16, 32, 64, 128, 256}
    # above condition is necessary because shared memory is limited
    # and we don't do additional blocking over head_size dim

    # collapse batch and head dims: one kernel program per (B*N, query-block)
    q = q.view(B * N, Lq, H)
    k = k.view(B * N, Lk, H)
    v = v.view(B * N, Lk, H)

    z = torch.empty_like(q)

    # z = torch.rand((B * N, Lq, Lk), dtype=q.dtype, device=q.device)

    assert q.is_contiguous()
    assert k.is_contiguous()
    assert v.is_contiguous()
    assert z.is_contiguous()

    scale = 1 / math.sqrt(H)

    BLOCK_SIZE_L = 64
    grid = (B * N, triton.cdiv(Lq, BLOCK_SIZE_L), 1)
    flash_attention_v1_kernel[grid](
        q,
        k,
        v,
        z,
        B * N,
        Lq,
        Lk,
        scale,
        H,
        dropout_prob=dropout_prob,
        BLOCK_SIZE_L=BLOCK_SIZE_L,
        # num_warps=8,
    )
    return z.view(B, N, Lq, H)


# --- kernels_attn.py ---
import torch
import os
import math
import triton
import triton.language as tl

import txl
from triton.tools.tensor_descriptor import TensorDescriptor
import triton.profiler.language as pl
import triton.profiler as proton
from txl.language.semantic import TXLSemantic

def _host_descriptor_pre_hook(nargs):
    # Autotune pre-hook: resize host-side TMA descriptors to the tile sizes
    # chosen by the autotuner. BLOCK_M is split across consumer warpgroups.
    NUM_CONSUMER_GROUPS = nargs.get("NUM_CONSUMERS", 1)
    BLOCK_M = nargs["BLOCK_M"] // NUM_CONSUMER_GROUPS
    BLOCK_N = nargs["BLOCK_N"]
    HEAD_DIM = nargs["HEAD_DIM"]
    FP8_OUTPUT = nargs.get("FP8_OUTPUT", False)
    # device-side descriptors are created in the kernel instead
    if not isinstance(nargs["desc_q"], TensorDescriptor):
        return
    nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
    if FP8_OUTPUT:
        # fp8 path loads V transposed: (HEAD_DIM, BLOCK_N)
        nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
    else:
        nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
    nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
    nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]

@triton.jit
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
    # Pass a host TensorDescriptor through unchanged; wrap a raw pointer
    # in a device-side tensor descriptor.
    if isinstance(desc_or_ptr, tl.tensor_descriptor):
        return desc_or_ptr
    else:
        return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)

# GPT2 at HEAD_DIM = 64
tma_ws_best_config = {'BLOCK_M':128, 'BLOCK_N':64, 'NUM_CONSUMERS': 2, 'NUM_STAGES': 2}
@txl.autotune(
    configs=[
        txl.Config(
            tma_ws_best_config,
            num_stages=2,
            num_warps=4,
            num_warpgroups=3,
            pre_hook = _host_descriptor_pre_hook,
        )
    ],
    key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT"],
    )
@txl.jit
def _attn_fwd_ws_tma_txl3_causal(sm_scale, M, #
              Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, #
              HEAD_DIM: tl.constexpr, #
              BLOCK_M: tl.constexpr, #
              BLOCK_N: tl.constexpr, #
              FP8_OUTPUT: tl.constexpr, #

              # NOTE: txl
              NUM_STAGES: tl.constexpr, #
              NUM_CONSUMERS: tl.constexpr #
              ):
    # Warp-specialized causal flash-attention forward (TXL/TMA version).
    # Warpgroup 0 is the producer: it issues TMA loads of Q/K/V into shared
    # memory stages. Warpgroups 1 and 2 are consumers: each computes the
    # upper/lower BLOCK_M//2 half of the query tile with an online softmax.
    # Synchronization is via mbarriers (producer->consumer: data ready,
    # consumer->producer: buffer free) plus named bars 8/9 between the two
    # consumer warpgroups. M receives per-row log-sum-exp statistics.
    dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
    byte_count: tl.constexpr = 2 if dtype == tl.float16 else 1
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    y_dim = Z * H * N_CTX
    # If no host desc, then make device desc
    desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                     block_shape=[BLOCK_M, HEAD_DIM])
    if FP8_OUTPUT:
        # fp8 path addresses V transposed: (HEAD_DIM, N_CTX)
        y_dim_v = Z * H * HEAD_DIM
        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim_v, N_CTX], strides=[N_CTX, 1],
                                         block_shape=[HEAD_DIM, BLOCK_N])
    else:
        desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                         block_shape=[BLOCK_N, HEAD_DIM])
    desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                     block_shape=[BLOCK_N, HEAD_DIM])
    desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
                                     block_shape=[BLOCK_M, HEAD_DIM])

    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
    qo_offset_y = offset_y + start_m * BLOCK_M


    # load q: it will stay in SRAM throughout
    #q = desc_q.load([qo_offset_y, 0])
    bQ0 = txl.smem_alloc([BLOCK_M//2, HEAD_DIM], dtype=dtype) # bQ has only 1 buffer for reuse only
    pMbar_bQ0 = txl.mbar_alloc(1)
    bQ1 = txl.smem_alloc([BLOCK_M//2, HEAD_DIM], dtype=dtype)
    pMbar_bQ1 = txl.mbar_alloc(1)

    # multi-stage K/V buffers for software pipelining
    bK = txl.smem_alloc([BLOCK_N, HEAD_DIM], dtype=dtype, num_stages=NUM_STAGES)
    if FP8_OUTPUT:
        bV = txl.smem_alloc([HEAD_DIM, BLOCK_N], dtype=dtype, num_stages=NUM_STAGES)
    else:
        bV = txl.smem_alloc([BLOCK_N, HEAD_DIM], dtype=dtype, num_stages=NUM_STAGES)
    pMbar_bK = txl.mbar_alloc(1, num_stages=NUM_STAGES)
    pMbar_bV = txl.mbar_alloc(1, num_stages=NUM_STAGES)

    # consumer-side "buffer free" barriers, one set per consumer warpgroup
    cMbar_QK1 = txl.mbar_alloc(128, num_stages=NUM_STAGES)
    cMbar_PV1 = txl.mbar_alloc(128, num_stages=NUM_STAGES)
    cMbar_QK2 = txl.mbar_alloc(128, num_stages=NUM_STAGES)
    cMbar_PV2 = txl.mbar_alloc(128, num_stages=NUM_STAGES)

    #WG1_BAR = 8
    #WG2_BAR = 9
    #WG_NUM_THREADS = 128 * 2

    # TODO: func type mismatch

    # range of values handled by this stage
    lo, hi = 0, (start_m + 1) * BLOCK_M
    mask_begin = start_m * BLOCK_M
    offsetkv_y = offset_y + lo


    if txl.is_warpgroup([0]):
        # producer warpgroup: frees registers and streams Q once, then K/V
        txl.reg_dealloc(24)

        bQ0i = txl.get_buffer(bQ0, 0)
        pMbar_bQ0i = txl.get_buffer(pMbar_bQ0, 0)
        bQ1i = txl.get_buffer(bQ1, 0)
        pMbar_bQ1i = txl.get_buffer(pMbar_bQ1, 0)

        txl.mbar_expect(pMbar_bQ0i, BLOCK_M // 2 * HEAD_DIM * byte_count)
        txl.tma_load(bQ0i, desc_q, [qo_offset_y, 0], pMbar_bQ0i)
        txl.mbar_wait(pMbar_bQ0i, 0)
        txl.mbar_expect(pMbar_bQ1i, BLOCK_M // 2 * HEAD_DIM * byte_count)
        txl.tma_load(bQ1i, desc_q, [qo_offset_y+BLOCK_M//2, 0], pMbar_bQ1i)
        txl.mbar_wait(pMbar_bQ1i, 0)

        bufIdxW = 0 # write buffer
        phase = 1

        for start_n in range(lo, hi, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            cur_mbar_bK = txl.get_buffer(pMbar_bK, bufIdxW)
            cur_mbar_bV = txl.get_buffer(pMbar_bV, bufIdxW)
            cur_bK = txl.get_buffer(bK, bufIdxW)
            cur_bV = txl.get_buffer(bV, bufIdxW)

            cur_mbar_QK1 = txl.get_buffer(cMbar_QK1, bufIdxW) # wait for the same buffer
            cur_mbar_PV1 = txl.get_buffer(cMbar_PV1, bufIdxW)
            cur_mbar_QK2 = txl.get_buffer(cMbar_QK2, bufIdxW)
            cur_mbar_PV2 = txl.get_buffer(cMbar_PV2, bufIdxW)

            # TODO: tma_expect_and_load
            # wait until BOTH consumers released this K stage before reloading
            txl.mbar_wait(cur_mbar_QK1, phase)
            txl.mbar_wait(cur_mbar_QK2, phase)
            txl.mbar_expect(cur_mbar_bK, BLOCK_N * HEAD_DIM * byte_count)
            txl.tma_load(cur_bK, desc_k, [offsetkv_y, 0], cur_mbar_bK)

            txl.mbar_wait(cur_mbar_PV1, phase)
            txl.mbar_wait(cur_mbar_PV2, phase)
            txl.mbar_expect(cur_mbar_bV, BLOCK_N * HEAD_DIM * byte_count)
            if FP8_OUTPUT:
                txl.tma_load(cur_bV, desc_v, [off_hz * HEAD_DIM, start_n], cur_mbar_bV)
            else:
                txl.tma_load(cur_bV, desc_v, [offsetkv_y, 0], cur_mbar_bV)

            offsetkv_y += BLOCK_N
            bufIdxW = (bufIdxW + 1) % NUM_STAGES
            if bufIdxW == 0:
                # mbarrier phase flips each time all stages have cycled
                phase = phase^1


    if txl.is_warpgroup([1, 2]):
        # consumer warpgroups: reclaim the producer's registers
        txl.reg_alloc(240)

        if txl.is_warpgroup([1]):
            # WG1 handles the top BLOCK_M//2 rows of the query tile
            offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M//2)

            # first let wg1 start
            #txl.bar_arrive(WG1_BAR, WG_NUM_THREADS)
            txl.bar_arrive(8, 256)
        else:
            # WG2 handles the bottom BLOCK_M//2 rows
            offs_m = start_m * BLOCK_M + tl.arange(BLOCK_M//2, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)

        # initialize pointer to m and l
        # These are in regs
        m_i = tl.zeros([BLOCK_M//2], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M//2], dtype=tl.float32) + 1.0
        acc = tl.zeros([BLOCK_M//2, HEAD_DIM], dtype=tl.float32)
        # load scales
        qk_scale = sm_scale
        qk_scale *= 1.44269504  # log2(e) = 1/ln(2): softmax uses exp2 below

        ## load and wait Q
        bQ0i = txl.get_buffer(bQ0, 0)
        pMbar_bQ0i = txl.get_buffer(pMbar_bQ0, 0)
        bQ1i = txl.get_buffer(bQ1, 0)
        pMbar_bQ1i = txl.get_buffer(pMbar_bQ1, 0)

        if txl.is_warpgroup([1]):
            txl.mbar_wait(pMbar_bQ0i, 0)
            # WG1 just start
            #txl.bar_wait(WG1_BAR, WG_NUM_THREADS)
        if txl.is_warpgroup([2]):
            txl.mbar_wait(pMbar_bQ1i, 0)
            # WG2 start after wg1 gemm0
            #txl.bar_wait(WG2_BAR, WG_NUM_THREADS)


        # -- prologue: first K block handled outside the main loop --

        # TODO: write in txl.jit for reuse
        ## load and wait K
        cur_mbar_bK = txl.get_buffer(pMbar_bK, 0)
        cur_bK = txl.get_buffer(bK, 0)
        txl.mbar_wait(cur_mbar_bK, 0)

        if txl.is_warpgroup([1]):
            cur_mbar_QK = txl.get_buffer(cMbar_QK1, 0)
            qk = tl.dot(bQ0i, cur_bK.T)

            # TODO whether before dot wait?
            #txl.bar_arrive(WG2_BAR, WG_NUM_THREADS)
            txl.dot_wait(0)
            txl.mbar_arrive(cur_mbar_QK)

        else: # [2]
            cur_mbar_QK = txl.get_buffer(cMbar_QK2, 0)
            qk = tl.dot(bQ1i, cur_bK.T)

            # TODO whether before dot wait?
            #txl.bar_arrive(WG1_BAR, WG_NUM_THREADS)
            txl.dot_wait(0)
            txl.mbar_arrive(cur_mbar_QK)

        # causal masking needed only from mask_begin (the diagonal block) on
        if lo >= mask_begin:
            mask = offs_m[:, None] >= (lo + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        #acc = acc * alpha[:, None]

        # update m_i and l_i
        m_i = m_ij

        # update acc
        p = p.to(dtype)

        bufIdxRK = 1
        bufIdxRV = 0
        phaseK = 0
        phaseV = 0

        # pass: p, l_i, m_i, acc
        # loop over k, v and update accumulator
        for start_n in range(lo+BLOCK_N, hi, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)

            # -- load k ----
            cur_mbar_bK = txl.get_buffer(pMbar_bK, bufIdxRK)
            cur_bK = txl.get_buffer(bK, bufIdxRK)
            txl.mbar_wait(cur_mbar_bK, phaseK)

            # Now only consider gemm 0 and softmax(gemm 1)
            # --- wait to start gemm 1 ---
            # case 1: wg1 early start
            # case 2: wait the release from last iter gemm 0 ends?
            if txl.is_warpgroup([1]):
                # WG1 just start
                #txl.bar_wait(WG1_BAR, WG_NUM_THREADS)
                txl.bar_wait(8, 256)
            if txl.is_warpgroup([2]):
                # WG2 start after wg1 gemm0
                #txl.bar_wait(WG2_BAR, WG_NUM_THREADS)
                txl.bar_wait(9, 256)

            if txl.is_warpgroup([1]):
                cur_mbar_QK = txl.get_buffer(cMbar_QK1, bufIdxRK)
                cur_mbar_PV = txl.get_buffer(cMbar_PV1, bufIdxRV)
                qk = tl.dot(bQ0i, cur_bK.T)

            else: # [2]
                cur_mbar_QK = txl.get_buffer(cMbar_QK2, bufIdxRK)
                cur_mbar_PV = txl.get_buffer(cMbar_PV2, bufIdxRV)
                qk = tl.dot(bQ1i, cur_bK.T)

            # -- compute pv j-1 ----
            # load v
            cur_mbar_bV = txl.get_buffer(pMbar_bV, bufIdxRV)
            cur_bV = txl.get_buffer(bV, bufIdxRV)
            txl.mbar_wait(cur_mbar_bV, phaseV)

            ## Downgrade if put before load v
            #if txl.is_warpgroup([1]):
            #    txl.bar_wait(WG1_BAR, WG_NUM_THREADS)
            #else:
            #    txl.bar_wait(WG2_BAR, WG_NUM_THREADS)

            # note that this non transposed v for FP8 is only supported on Blackwell
            if FP8_OUTPUT:
                acc = tl.dot(p, cur_bV.T, acc)
            else:
                acc = tl.dot(p, cur_bV, acc)


            # TODO: before or after wait? oh previously is also before QK wait
            if txl.is_warpgroup([1]):
                #txl.bar_arrive(WG2_BAR, WG_NUM_THREADS)
                txl.bar_arrive(9, 256)
            else:
                #txl.bar_arrive(WG1_BAR, WG_NUM_THREADS)
                txl.bar_arrive(8, 256)
            txl.dot_wait(1)
            # --- release QK finished ---
            txl.mbar_arrive(cur_mbar_QK)

            #m_i, l_i, p, alpha = softmax_txl(m_i, l_i, qk, qk_scale, dtype)

            # -- compute softmax, block arg updates ----
            if start_n >= mask_begin:
                mask = offs_m[:, None] >= (start_n + offs_n[None, :])
                qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
                m_ij = tl.maximum(m_i, tl.max(qk, 1))
                qk -= m_ij[:, None]
            else:
                m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
                qk = qk * qk_scale - m_ij[:, None]

            # update p
            p = tl.math.exp2(qk)
            l_ij = tl.sum(p, 1)
            # update m_i and l_i
            alpha = tl.math.exp2(m_i - m_ij)
            l_i = l_i * alpha + l_ij
            m_i = m_ij

            # update output accumulator
            txl.dot_wait(0)
            # --- release PV j-1 finished ---
            txl.mbar_arrive(cur_mbar_PV)

            acc = acc * alpha[:, None]

            # update acc, NOTE: p position is important
            p = p.to(dtype)

            bufIdxRK = (bufIdxRK + 1) % NUM_STAGES
            if bufIdxRK == 0:
                phaseK = phaseK ^ 1
            bufIdxRV = (bufIdxRV + 1) % NUM_STAGES
            if bufIdxRV == 0:
                phaseV = phaseV ^ 1


        #if txl.is_warpgroup([1]):
        #    txl.bar_arrive(WG2_BAR, WG_NUM_THREADS)
        #else:
        #    txl.bar_arrive(WG1_BAR, WG_NUM_THREADS)

        # -- last iter: drain the final P @ V --
        # load v
        cur_mbar_bV = txl.get_buffer(pMbar_bV, bufIdxRV)
        #if txl.is_warpgroup([1]):
        #    cur_mbar_PV = txl.get_buffer(cMbar_PV1, bufIdxRV)
        #else:
        #    cur_mbar_PV = txl.get_buffer(cMbar_PV2, bufIdxRV)
        cur_bV = txl.get_buffer(bV, bufIdxRV)
        txl.mbar_wait(cur_mbar_bV, phaseV)

        # note that this non transposed v for FP8 is only supported on Blackwell
        if FP8_OUTPUT:
            acc = tl.dot(p, cur_bV.T, acc)
        else:
            acc = tl.dot(p, cur_bV, acc)
        txl.dot_wait(0)
        #txl.mbar_arrive(cur_mbar_PV)

        # epilogue: store log-sum-exp stats and the normalized output
        m_i += tl.math.log2(l_i)
        acc = acc / l_i[:, None]
        m_ptrs = M + off_hz * N_CTX + offs_m
        tl.store(m_ptrs, m_i)

        if txl.is_warpgroup([1]):
            desc_o.store([qo_offset_y, 0], acc.to(dtype))
        if txl.is_warpgroup([2]):
            desc_o.store([qo_offset_y+BLOCK_M//2, 0], acc.to(dtype))

@torch.no_grad()
def flash_attention_v2(q, k, v, dropout_prob=0.0):
    """Host wrapper for the warp-specialized TMA attention kernel.

    q, k, v: (batch, heads, seqlen, head_dim). Returns the attention
    output with the same shape as ``q``.
    NOTE(review): ``dropout_prob`` is accepted but never passed to the
    kernel — confirm whether dropout is intentionally unsupported here.
    """
    HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
    HEAD_DIM_V = v.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
    o = torch.empty_like(q)
    extra_kern_args = {}
    # M holds per-row log-sum-exp statistics written by the kernel
    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

    y_dim = q.shape[0] * q.shape[1] * q.shape[2]

    # block shapes are overwritten by the autotune pre-hook
    dummy_block = [1, 1]
    desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
    desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
    desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
    desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)

    def alloc_fn(size: int, align: int, _):
        # scratch allocator for device-side TMA descriptors
        return torch.empty(size, dtype=torch.int8, device="cuda")

    triton.set_allocator(alloc_fn)

    def grid(META):
        return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)

    sm_scale = 1 / math.sqrt(HEAD_DIM_K)

    _attn_fwd_ws_tma_txl3_causal[grid](
        sm_scale, M, #
        q.shape[0], q.shape[1], #
        desc_q, desc_k, desc_v, desc_o, #
        N_CTX=q.shape[2], #
        HEAD_DIM=HEAD_DIM_K, #
        FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
        **extra_kern_args)

    return o
def _matmul_launch_metadata(grid, kernel, args):
    """Build proton launch metadata (name, flop count, byte traffic) for a matmul."""
    m, n, k = args["M"], args["N"], args["K"]
    suffix = "_ws" if args.get("WARP_SPECIALIZE", False) else ""
    # element size: from the output pointer when available, else infer from dtype flag
    if "c_ptr" in args:
        elem_bytes = args["c_ptr"].element_size()
    else:
        elem_bytes = 1 if args["FP8_OUTPUT"] else 2
    return {
        "name": f"{kernel.name}{suffix} [M={m}, N={n}, K={k}]",
        f"flops{elem_bytes * 8}": 2. * m * n * k,
        "bytes": elem_bytes * (m * k + n * k + m * n),
    }

def matmul_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: resize the TMA descriptors to the chosen tile sizes."""
    consumers = nargs.get("NUM_CONSUMER_GROUPS", 1)
    # each consumer warpgroup owns a BLOCK_M / consumers slice of the M tile
    block_m = nargs["BLOCK_SIZE_M"] // consumers
    block_n = nargs["BLOCK_SIZE_N"]
    block_k = nargs["BLOCK_SIZE_K"]
    nargs["a_desc"].block_shape = [block_m, block_k]
    nargs["b_desc"].block_shape = [block_n, block_k]
    # with EPILOGUE_SUBTILE the output tile is stored in two N halves
    out_n = block_n // 2 if nargs.get("EPILOGUE_SUBTILE", False) else block_n
    nargs["c_desc"].block_shape = [block_m, out_n]
@txl.autotune(
    configs=[
        txl.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
                "NUM_CONSUMER_GROUPS": 2,
                "NUM_STAGES": 2,
            },
            num_stages=2,
            num_warps=4,
            num_warpgroups=3,
            pre_hook=matmul_tma_set_block_size_hook
        ),
    ],
    key=["M", "N", "K"],
    use_cuda_graph=True,
)

@txl.jit(launch_metadata=_matmul_launch_metadata)
def matmul_persistent_ws_tma_txl_kernel(
    a_desc,
    b_desc,
    c_desc,
    M,
    N,
    K,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr, #
    GROUP_SIZE_M: tl.constexpr, #
    FP8_OUTPUT: tl.constexpr, #
    NUM_CONSUMER_GROUPS: tl.constexpr,
    NUM_STAGES: tl.constexpr,

    # 3.4.x
    #EPILOGUE_SUBTILE: tl.constexpr, #
    NUM_SMS: tl.constexpr, #
    WARP_SPECIALIZE: tl.constexpr, #

    # ffn epilogue (bias + gelu + residual)
    apply_gelu: tl.constexpr=False,
    bias_ptr=None,
    residual_ptr=None,
):
    # Persistent warp-specialized matmul with a fused FFN epilogue.
    # Warpgroup 0 is the producer (TMA loads of A/B into shared-memory
    # stages); warpgroups 1 and 2 are consumers, each computing one
    # BLOCK_SIZE_M//2 half of the output tile. Tiles are distributed
    # persistently over NUM_SMS programs with grouped (GROUP_SIZE_M)
    # tile ordering for L2 reuse.
    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
    byte_count: tl.constexpr = 2 if dtype == tl.float16 else 1
    num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    # A is split into two M-halves, one per consumer warpgroup; B is shared
    a0 = txl.smem_alloc([BLOCK_SIZE_M//2, BLOCK_SIZE_K], dtype=dtype, num_stages=NUM_STAGES)
    a1 = txl.smem_alloc([BLOCK_SIZE_M//2, BLOCK_SIZE_K], dtype=dtype, num_stages=NUM_STAGES)
    b0 = txl.smem_alloc([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=dtype, num_stages=NUM_STAGES)

    mbar_producer_a0 = txl.mbar_alloc(1, num_stages=NUM_STAGES)
    mbar_producer_a1 = txl.mbar_alloc(1, num_stages=NUM_STAGES)
    mbar_producer_b0 = txl.mbar_alloc(1, num_stages=NUM_STAGES)
    mbar_consumer1 = txl.mbar_alloc(128, num_stages=NUM_STAGES)
    mbar_consumer2 = txl.mbar_alloc(128, num_stages=NUM_STAGES)


    if txl.is_warpgroup([0]):
        # producer: walk the same persistent tile schedule as the consumers
        phase = 1
        bufIdx = 0
        for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
            # grouped tile ordering (same mapping as the consumer loop below)
            group_id = pid // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
            pid_n = (pid % num_pid_in_group) // group_size_m
            #pid_m = pid % num_pid_m
            #pid_n = pid // num_pid_m

            offs_am = pid_m * BLOCK_SIZE_M
            offs_bn = pid_n * BLOCK_SIZE_N
            offs_k = 0

            for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
                mbar_c1 = txl.get_buffer(mbar_consumer1, bufIdx)
                mbar_c2 = txl.get_buffer(mbar_consumer2, bufIdx)

                mbar_p_a0 = txl.get_buffer(mbar_producer_a0, bufIdx)
                mbar_p_a1 = txl.get_buffer(mbar_producer_a1, bufIdx)
                mbar_p_b0 = txl.get_buffer(mbar_producer_b0, bufIdx)

                a0_buf = txl.get_buffer(a0, bufIdx)
                a1_buf = txl.get_buffer(a1, bufIdx)
                b0_buf = txl.get_buffer(b0, bufIdx)

                # wait for the consumer to release this stage, then refill it
                txl.mbar_wait(mbar_c1, phase)
                txl.mbar_expect(mbar_p_a0, BLOCK_SIZE_M//2*BLOCK_SIZE_K*byte_count)
                txl.tma_load(a0_buf, a_desc, [offs_am, offs_k], mbar_p_a0)

                txl.mbar_wait(mbar_c2, phase)
                txl.mbar_expect(mbar_p_b0, BLOCK_SIZE_N*BLOCK_SIZE_K*byte_count)
                txl.tma_load(b0_buf, b_desc, [offs_bn, offs_k], mbar_p_b0)


                txl.mbar_expect(mbar_p_a1, BLOCK_SIZE_M//2*BLOCK_SIZE_K*byte_count)
                txl.tma_load(a1_buf, a_desc, [offs_am + BLOCK_SIZE_M // 2, offs_k], mbar_p_a1)

                offs_k += BLOCK_SIZE_K
                bufIdx = (bufIdx + 1) % NUM_STAGES
                if bufIdx == 0:
                    phase = phase^1

    if txl.is_warpgroup([1, 2]): # TODO: else
        phase = 0
        bufIdx = 0
        for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
            group_id = pid // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
            pid_n = (pid % num_pid_in_group) // group_size_m
            #pid_m = pid % num_pid_m
            #pid_n = pid // num_pid_m

            offs_am = pid_m * BLOCK_SIZE_M
            offs_bn = pid_n * BLOCK_SIZE_N
            offs_k = 0
            accumulator = tl.zeros((BLOCK_SIZE_M//2, BLOCK_SIZE_N), dtype=tl.float32)
            for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
                mbar_p_b0 = txl.get_buffer(mbar_producer_b0, bufIdx)

                b0_buf = txl.get_buffer(b0, bufIdx)
                txl.mbar_wait(mbar_p_b0, phase)
                if txl.is_warpgroup([1]):
                    mbar_p_a0 = txl.get_buffer(mbar_producer_a0, bufIdx)
                    mbar_c1 = txl.get_buffer(mbar_consumer1, bufIdx)
                    a0_buf = txl.get_buffer(a0, bufIdx)
                    txl.mbar_wait(mbar_p_a0, phase)
                    accumulator = tl.dot(a0_buf, b0_buf.T, accumulator) # accumulator is reg, no contention among buffers
                    txl.dot_wait(0)
                    txl.mbar_arrive(mbar_c1)
                if txl.is_warpgroup([2]): # TODO: else test
                    mbar_p_a1 = txl.get_buffer(mbar_producer_a1, bufIdx)
                    mbar_c2 = txl.get_buffer(mbar_consumer2, bufIdx)
                    a1_buf = txl.get_buffer(a1, bufIdx)
                    txl.mbar_wait(mbar_p_a1, phase)
                    accumulator = tl.dot(a1_buf, b0_buf.T, accumulator)
                    txl.dot_wait(0)
                    txl.mbar_arrive(mbar_c2)

                offs_k += BLOCK_SIZE_K
                bufIdx = (bufIdx + 1) % NUM_STAGES
                if bufIdx == 0: # TODO: pipelinestate
                    phase = phase^1

            # --- fused FFN epilogue: bias + gelu + residual ---
            offs_n = offs_bn + tl.arange(0, BLOCK_SIZE_N)[None, :]

            if bias_ptr is not None:
                bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
                accumulator += bias.to(tl.float32)
            if apply_gelu:
                # tanh-approximation GELU; tanh(t) computed as 2*sigmoid(2t) - 1
                a:tl.constexpr = 0.797885
                b:tl.constexpr = 0.044715
                x = accumulator
                x_cube = x * x * x
                tanh_res = 2.0 * tl.sigmoid(2.0 * a * (x + b * x_cube)) - 1.0
                accumulator = 0.5 * x * (1.0 + tanh_res)
            if residual_ptr is not None:
                # each consumer warpgroup reads the residual rows it owns
                if txl.is_warpgroup([1]):
                    offs_m = offs_am + tl.arange(0, BLOCK_SIZE_M//2)[:, None]
                else:
                    offs_m = offs_am + BLOCK_SIZE_M//2 + tl.arange(0, BLOCK_SIZE_M//2)[:, None]
                mask = (offs_m < M) & (offs_n < N)
                residual = tl.load(residual_ptr + offs_m * N + offs_n, mask=mask, other=0.0)
                accumulator += residual.to(tl.float32)

            c = accumulator.to(dtype)
            if txl.is_warpgroup([1]):
                c_desc.store([offs_am, offs_bn], c)
            if txl.is_warpgroup([2]):
                c_desc.store([offs_am + BLOCK_SIZE_M//2, offs_bn], c)

@torch.no_grad()
def fused_ffn_txl(
    a,
    b,
    bias=None,
    residual=None,
    add_gelu=False,
    dropout_prob=0.0,
):
    """Host wrapper for the persistent warp-specialized fused FFN matmul.

    a: (*, K) input; b: (K, N) weight (transposed internally for TMA);
    bias: optional (N,); residual: optional, reshaped to (M, N).
    Returns a tensor of shape (*, N).
    NOTE(review): the kernel is only launched when dtype == torch.float16;
    for any other dtype ``c`` is returned uninitialized — confirm this is
    intentional. ``dropout_prob`` is accepted but never used here.
    """
    out_shape_0 = a.shape[:-1]
    a = a.view(-1, a.shape[-1])
    # TMA path wants B as (N, K) row-major
    b = b.T.contiguous()

    # Check constraints.
    assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed
    assert a.dtype == b.dtype, "Incompatible dtypes"

    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype

    c = torch.empty((M, N), device=a.device, dtype=dtype)
    if bias is not None: # bias: (N,)
        assert bias.is_contiguous()
        assert b.shape[0] == bias.shape[0]
    if residual is not None:
        residual = residual.view(c.shape)
        assert residual.is_contiguous()

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # A dummy block value that will be overwritten when we have the real block size
    dummy_block = [1, 1]
    a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
    b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
    c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

    def grid(META):
        nonlocal a_desc, b_desc, c_desc
        BLOCK_M = META["BLOCK_SIZE_M"]
        BLOCK_N = META["BLOCK_SIZE_N"]
        # persistent launch: at most one program per SM
        return (min(
            NUM_SMS,
            triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        ), )

    if dtype == torch.float16:
        matmul_persistent_ws_tma_txl_kernel[grid](
            a_desc, b_desc, c_desc, #
            M, N, K, #
            FP8_OUTPUT=dtype == torch.float8_e4m3fn, #
            NUM_SMS=NUM_SMS, #
            WARP_SPECIALIZE=False, #
            apply_gelu=add_gelu,
            bias_ptr=bias,
            residual_ptr=residual,
        )
    return c.view(*out_shape_0, N)
test_file_kernels_attn = local_dir / "kernels_attn.py"
test_file_kernels_ffn = local_dir / "kernels_ffn.py"

app = App(name="txl") # Note: this is optional since Modal 0.57
volume = Volume.from_name("txl-dump", create_if_missing=True) # create a cloud volume to store compiled dump files

# Build the remote image: CUDA base + python 3.12 + local txl wheel,
# with the benchmark/test sources mounted (not baked) so edits don't rebuild.
txl_image = (
    Image.from_registry(
        "nvidia/cuda:12.4.0-devel-ubuntu22.04",
        add_python="3.12",
    )
    #Image.from_dockerfile(path="./Dockerfile")
    .workdir("/workspace")
    .add_local_file(txl_wheel_file, remote_path="/workspace/", copy=True) # copy the local code to the image
    .run_commands( "ls .")
    .pip_install_from_requirements(requirements_file) # local file not remote file
    .pip_install("jaxtyping")
    .run_commands(
        "pip install /workspace/txl-3.4.0-cp312-cp312-linux_x86_64.whl",
    )
    .add_local_file(test_file_bench, remote_path="/workspace/bench.py", copy=False) # copy after image build, no need rebuild
    .add_local_file(test_file_gpt, remote_path="/workspace/gpt.py", copy=False)
    .add_local_file(test_file_kernels, remote_path="/workspace/kernels.py", copy=False)
    .add_local_file(test_file_test, remote_path="/workspace/test_txl.py", copy=False)
    .add_local_file(test_file_tmp, remote_path="/workspace/tmp.py", copy=False)
    .add_local_file(test_file_kernels_attn, remote_path="/workspace/kernels_attn.py", copy=False)
    .add_local_file(test_file_kernels_ffn, remote_path="/workspace/kernels_ffn.py", copy=False)
)

# Example function that uses the image
@app.function(gpu="H100", image=txl_image, timeout=60*60,
        volumes={"/workspace/dump": volume})
def run_demo():
    """Modal entry point: on an H100 SXM node, preload cuBLAS and run bench.py,
    streaming + persisting its output to the mounted dump volume."""
    import subprocess, sys, os, torch, time
    def get_gpu_type():
        # returns True only for H100 SXM (HBM3) parts, per nvidia-smi output

        try:
            # Execute nvidia-smi command to query GPU details
            result = subprocess.run(['nvidia-smi', '-q'], capture_output=True, text=True, check=True)
            output = result.stdout

            # Look for indicators of SXM or PCIe in the output
            for line in output.split("\n"):
                if "Product Name" in line:
                    print(line)
                if 'H100' in line and 'HBM3' in line:
                    return True
        except subprocess.CalledProcessError as e:
            print(f"Error running nvidia-smi: {e}")
        except FileNotFoundError:
            print("nvidia-smi not found. Please ensure NVIDIA drivers are installed and in your PATH.")
        return False

    def test_demo():
        # run bench.py as a subprocess, teeing stdout to console and a log file
        os.makedirs("/workspace/dump", exist_ok=True)
        logs_dir = pathlib.Path("/workspace/dump/logs")
        logs_dir.mkdir(parents=True, exist_ok=True)
        ts = time.strftime("%Y%m%d-%H%M%S")
        log_path = logs_dir / f"mla-{ts}.log"

        env = os.environ.copy()
        # env["TRITON_PRINT_AUTOTUNING"] = "0"
        # env["TRITON_KERNEL_DUMP"] = "1"
        # env["TRITON_DUMP_DIR"] = "/workspace/dump"
        # env["TRITON_ALWAYS_COMPILE"] = "1"
        # env["CUDA_LAUNCH_BLOCKING"] = "1"
        # env["TRITON_LLVM_DEBUG_ONLY"] = "txlgpu-pipeliner"

        cmd = [sys.executable, "-u", "/workspace/bench.py"]

        with open(log_path, "w", buffering=1, encoding="utf-8", errors="replace") as f:
            proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                text=True, env=env, bufsize=1
            )
            assert proc.stdout is not None
            for line in proc.stdout:
                print(line, end="")
                f.write(line)

            rc = proc.wait()

        print(f"\n=== FULL LOG SAVED ===\n{log_path}\n")
        # propagate the benchmark's exit status
        if rc != 0:
            raise SystemExit(rc)

    def import_cuBLAS_lib():
        # preload pip-installed cuBLAS/cudart so triton/torch can dlopen them
        import os, ctypes, pathlib
        import nvidia.cublas, nvidia.cuda_runtime

        cublas_dir = (pathlib.Path(nvidia.cublas.__file__).parent / "lib").resolve()
        cudart_dir = (pathlib.Path(nvidia.cuda_runtime.__file__).parent / "lib").resolve()

        os.environ["LD_LIBRARY_PATH"] = f"{cublas_dir}:{cudart_dir}:" + os.environ.get("LD_LIBRARY_PATH","")

        # create unversioned symlinks (libcublas.so -> libcublas.so.12) if missing
        for name12, name in [("libcublas.so.12","libcublas.so"),
                             ("libcublasLt.so.12","libcublasLt.so")]:
            src = cublas_dir / name12
            dst = cublas_dir / name
            if src.exists() and not dst.exists():
                try: dst.symlink_to(name12)
                except FileExistsError: pass
                except PermissionError: pass

        # RTLD_GLOBAL so later dlopens resolve against these copies
        ctypes.CDLL(str(cudart_dir / "libcudart.so.12"), mode=ctypes.RTLD_GLOBAL)
        ctypes.CDLL(str(cublas_dir / "libcublasLt.so.12"), mode=ctypes.RTLD_GLOBAL)
        ctypes.CDLL(str(cublas_dir / "libcublas.so.12"), mode=ctypes.RTLD_GLOBAL)

    if get_gpu_type():
        print("Running on H100 SXM")
        import_cuBLAS_lib()
        test_demo()

# --- test.py ---
# TRITON_INTERPRET=1 pytest -sv test.py

import math

import pytest
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from transformers.activations import ACT2FN

from gpt import (FusedGPT, GPTConfig, convert_hf_and_load_model, estimate_days,
                 get_num_parameters)
from kernels import (flash_attention_v1, fused_embeddings, fused_ffn,
                     fused_layer_norm)


def _get_inputs(M, K, N, device):
    """Deterministic random inputs for the fused-FFN tests.

    Returns (x, w, b, r): x is (M, K), w is (K, N), b is (N,); r is a
    residual shaped like x, or None when K != N (the residual add needs
    the output shape (M, N) to match x's shape).
    """
    torch.manual_seed(1337)
    x = torch.rand((M, K), device=device, dtype=torch.float32)
    w = torch.rand((K, N), device=device, dtype=torch.float32)
    b = torch.rand((N,), device=device, dtype=torch.float32)
    r = torch.rand_like(x, dtype=torch.float32)
    if K != N:
        # NOTE(review): `r_torch` is assigned but never used — looks like a
        # leftover from a removed reference variable; consider dropping it.
        r = r_torch = None
    return x, w, b, r
hidden_size), device=device) + + z_torch = wte[x] + wpe[torch.arange(x.shape[1], device=device)][None] + z = fused_embeddings(x, wte, wpe) + + assert torch.allclose(z, z_torch, atol=1e-5), (z - z_torch).abs().max() + + +@pytest.mark.parametrize("M", [249, 32]) +@pytest.mark.parametrize("K", [123, 128, 64]) +def test_fused_layer_norm(M, K): + N = 32 + device = "cuda" if torch.cuda.is_available() else "cpu" + x, *_ = _get_inputs(M, K, N, device) + x_torch, *_ = _get_inputs(M, K, N, device) + + layer_norm = nn.LayerNorm(K).to(device) + x_torch = layer_norm(x_torch) + x = fused_layer_norm(x, layer_norm.weight.data, layer_norm.bias.data) + + assert torch.allclose(x, x_torch, atol=1e-5), (x - x_torch).abs().max() + + +def torch_ffn(x, w, b=None, r=None): + z = x @ w + if b is not None: + z += b + z = ACT2FN["gelu_new"](z) + if r is not None: + z += r + return z + + +@pytest.mark.parametrize("M,N,K", [(128, 128, 256), (199, 129, 129), (61, 31, 23)]) +@pytest.mark.parametrize("add_gelu", [True, False]) +@pytest.mark.parametrize("add_bias", [True, False]) +def test_fused_ffn(M, N, K, add_gelu, add_bias): + device = "cuda" if torch.cuda.is_available() else "cpu" + x_torch, w_torch, b_torch, r_torch = _get_inputs(M, K, N, device) + x, w, b, r = _get_inputs(M, K, N, device) + + if not add_bias: + b_torch = None + b = None + + z_torch = torch_ffn(x_torch, w_torch, b=b_torch, r=r_torch) + + z = fused_ffn(x, w, bias=b, residual=r, add_gelu=True) + assert torch.allclose(z, z_torch, atol=1e-5), (z - z_torch).abs().max() + + +def _get_attn_inputs(B, N, L, H, device): + torch.manual_seed(1337) + q = torch.rand((B, N, L, H), device=device) + k = torch.rand_like(q) + v = torch.rand_like(q) + return q, k, v + + +def torch_attention(q, k, v): + assert q.shape == k.shape == v.shape + B, N, L, H = q.shape + q, k, v = map(lambda x: x.view(B * N, L, H), (q, k, v)) + z = (q @ k.transpose(1, 2)) / math.sqrt(H) + attn_mask = torch.tril(torch.ones((L, L), dtype=torch.bool)) + z = 
torch.where(attn_mask, z, float("-inf")) + z = z.softmax(-1) @ v + return z.view(B, N, L, H) + + +@pytest.mark.parametrize("B,N", [(3, 9), (2, 7)]) +@pytest.mark.parametrize("L", [199, 128, 63]) +@pytest.mark.parametrize("H", [64, 128, 256]) +def test_flash_attention_v1(B, N, L, H): + device = "cuda" if torch.cuda.is_available() else "cpu" + q, k, v = _get_attn_inputs(B, N, L, H, device) + z_torch = torch_attention(q, k, v) + z = flash_attention_v1(q, k, v) + assert torch.allclose(z, z_torch, atol=1e-5), (z - z_torch).abs().max() + + +def test_gpt2(): + device = "cuda" if torch.cuda.is_available() else "cpu" + model_id = "gpt2" + model, hf_model = convert_hf_and_load_model(model_id, device) + tokenizer = AutoTokenizer.from_pretrained(model_id) + with torch.no_grad(): + string = "I am vasudev gupta. I like AI." + inputs = tokenizer(string, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + hf_out = hf_model(**inputs).last_hidden_state + out = model(inputs["input_ids"]) + print((out - hf_out).abs()) + # TODO: need to look at why we can't do low precision + assert torch.allclose(out, hf_out, atol=1e-1), (out - hf_out).abs().max() + + +def test_flops(): + config = GPTConfig() + model = FusedGPT(config).eval() + num_tokens = 1024 + fwd_flops = model.get_fwd_flops(num_tokens) + total_flops = fwd_flops * 3 + num_parameters = get_num_parameters(model) + r = (fwd_flops * 3) / (6 * num_parameters * num_tokens) + assert r >= 0.9995, r + + +def test_estimate_days(): + # llama-3.1 paper reports 54 days for pre-training 405B parameter model + # its very close to what we get from following equation + flops = 6 * (405 * 10**9) * (15 * 10**12) + t = estimate_days(flops, mfu=0.45, gpu="h100", num_gpus=16_000) + assert t == 59.24544994944388, t From 7881d7fcdc3512d5dbc48aac7866e43ee0984dd2 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Mon, 8 Dec 2025 10:14:28 +0800 Subject: [PATCH 2/6] test profile on modal --- 
python/txl/tutorials/02-flash-attention.py | 7 ++++--- tools/profile.sh | 10 ++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python/txl/tutorials/02-flash-attention.py b/python/txl/tutorials/02-flash-attention.py index 63a124f..1aec7b4 100644 --- a/python/txl/tutorials/02-flash-attention.py +++ b/python/txl/tutorials/02-flash-attention.py @@ -3080,10 +3080,11 @@ def run_test(algo=0, dump_dir=None): #test_op(1, 2, 1536, 128, False, dtype=torch.float16, algo=4, no_tune=no_tune, profiling=PROFILING) test_op(16, 32, 1024, 128, False, dtype=torch.float16, algo=algo, no_tune=no_tune, profiling=PROFILING) - print("BENCH...") - bench_flash_attention.run(save_path=".", print_data=True, algo=algo, no_tune=no_tune) + # print("BENCH...") + # bench_flash_attention.run(save_path=".", print_data=True, algo=algo, no_tune=no_tune) if __name__ == "__main__": #run_test(6, dump_dir='dump/fa1113') #run_test(5, dump_dir='dump/1124fa') - run_test(3, dump_dir='dump/11251dot3') + # run_test(3, dump_dir='dump/11251dot3') + run_test(4) diff --git a/tools/profile.sh b/tools/profile.sh index fbcc811..9edc8ce 100644 --- a/tools/profile.sh +++ b/tools/profile.sh @@ -1,4 +1,4 @@ -export CUDA_VISIBLE_DEVICES=1 +# export CUDA_VISIBLE_DEVICES=0 export TRITON_ALWAYS_COMPILE=1 CMD=$1 # ncu, nsys @@ -43,13 +43,15 @@ REGEX="flash|txl|coopA|softmax_kernel" # denormal METRICS="sm__sass_thread_inst_executed_op_fp16_pred_on,sm__sass_thread_inst_executed_op_fp32_pred_on" - -PY_SCRIPT=python/txl/tutorials/02-flash-attention.py +PY_SCRIPT=test_txl.py +# PY_SCRIPT=python/txl/tutorials/02-flash-attention.py #PY_SCRIPT=python/txl/tutorials/04-softmax.py #PY_SCRIPT=python/txl/tests/wgid.py #PY_SCRIPT="python/txl/tutorials/01-matmul.py -K 16384" #PY_SCRIPT="python/txl/tutorials/01-matmul.py -K 2048" +OUTPUT="/workspace/dump/nsys_profile" + # Convert comma-separated sections to multiple --section flags section_flags=() if [ -n "$SECTIONS" ]; then @@ -71,7 +73,7 @@ elif [ "$CMD" == "nsys" ]; 
then nsys profile \ -t cuda \ ${OUTPUT:+-o "$OUTPUT"} \ - python $PY_SCRIPT + python -u $PY_SCRIPT else echo "Usage: $0 [ncu|nsys]" exit 1 From 70c2eea43634ff588bd3bf276ea581bfabfcc1d2 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Mon, 8 Dec 2025 16:38:35 +0800 Subject: [PATCH 3/6] update bench --- docker/end_to_end/bench.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/docker/end_to_end/bench.py b/docker/end_to_end/bench.py index fafea6d..7e4d707 100644 --- a/docker/end_to_end/bench.py +++ b/docker/end_to_end/bench.py @@ -114,11 +114,25 @@ def validate(mixed_precison=True, max_length=1024): print("start") validate(mixed_precison=True, max_length=1024) -print("triton:", run_benchmark("triton")) -print("txl:", run_benchmark("txl")) -print("torch:", run_benchmark("torch")) -print("triton long:", run_benchmark("triton", max_length=4096, need_gpt=False)) -print("txl long:", run_benchmark("txl", max_length=4096, need_gpt=False)) +# print("triton:", run_benchmark("triton")) +# print("txl:", run_benchmark("txl")) +# print("torch:", run_benchmark("torch")) +# print("triton long:", run_benchmark("triton", max_length=4096, need_gpt=False)) +# print("txl long:", run_benchmark("txl", max_length=4096, need_gpt=False)) +for i in range(8, 11): + length = 2**i + torch_time = run_benchmark("torch", max_length=length) + triton_time = run_benchmark("triton", max_length=length) + txl_time = run_benchmark("txl", max_length=length) + speedup_triton = torch_time / triton_time + speedup_txl = torch_time / txl_time + print(f"length: {length}, torch: {torch_time:.2f} ms, triton: {triton_time:.2f} ms, txl: {txl_time:.2f} ms, speedup_triton: {speedup_triton:.2f}x, speedup_txl: {speedup_txl:.2f}x") +for i in range(11, 13): + length = 2**i + triton_time = run_benchmark("triton", max_length=length, need_gpt=False) + txl_time = run_benchmark("txl", max_length=length, need_gpt=False) + speedup_txl = triton_time / txl_time + 
print(f"length: {length}, triton: {triton_time:.2f} ms, txl: {txl_time:.2f} ms, speedup_txl: {speedup_txl:.2f}x") # print("txl test:", run_benchmark("txl", max_length=1024, need_gpt=False)) # print("triton test:", run_benchmark("triton", max_length=1024, need_gpt=False)) # OLD SUMMARY From 258099396a5011c7470be739a5e2fb8e858a2ff0 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Mon, 8 Dec 2025 16:39:03 +0800 Subject: [PATCH 4/6] update bench --- docker/end_to_end/gpt.py | 2 +- docker/end_to_end/modal_test.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docker/end_to_end/gpt.py b/docker/end_to_end/gpt.py index 7784ebb..90e2dda 100644 --- a/docker/end_to_end/gpt.py +++ b/docker/end_to_end/gpt.py @@ -355,7 +355,7 @@ def load_model_only(device): config = GPTConfig( vocab_size=50304, block_size=16384, - n_layer=24, + n_layer=12, n_head=12, n_embd=768, dropout=0.0, diff --git a/docker/end_to_end/modal_test.py b/docker/end_to_end/modal_test.py index 0fde161..7e37eb6 100644 --- a/docker/end_to_end/modal_test.py +++ b/docker/end_to_end/modal_test.py @@ -1,8 +1,9 @@ from modal import Image, App, Volume import pathlib local_dir = pathlib.Path(__file__).parent +root_dir = local_dir.parent requirements_file = local_dir / "requirements.txt" -txl_wheel_file = local_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" +txl_wheel_file = root_dir / "txl-3.4.0-cp312-cp312-linux_x86_64.whl" test_file_bench = local_dir / "bench.py" test_file_gpt = local_dir / "gpt.py" From 960078404255f6b439a13b1929a2d3aec7d507e5 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Mon, 8 Dec 2025 17:01:55 +0800 Subject: [PATCH 5/6] add draw --- docker/end_to_end/draw_end2end.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 docker/end_to_end/draw_end2end.py diff --git a/docker/end_to_end/draw_end2end.py b/docker/end_to_end/draw_end2end.py new file mode 100644 index 0000000..102ea0c --- /dev/null +++ 
b/docker/end_to_end/draw_end2end.py @@ -0,0 +1,29 @@ +import matplotlib.pyplot as plt + +# 原始数据 +length_labels = [256, 512, 1024, 2048, 4096] # 用来当 x 轴刻度标签 +x = list(range(len(length_labels))) # 实际用于绘图的 x 坐标:0,1,2 + +torch_times = [14.92, 28.43, 55.59] +triton_times = [5.93, 13.09, 32.74, 76.94, 202.10] +txl_times = [8.28, 10.53, 22.13, 45.38, 102.71] + +plt.figure(figsize=(6, 4)) + +# 用 0,1,2 作为横坐标,这样间隔一定相同 +plt.plot(list(range(len(torch_times))), torch_times, marker='o', label='PyTorch') +plt.plot(list(range(len(triton_times))), triton_times, marker='s', label='Triton') +plt.plot(list(range(len(txl_times))), txl_times, marker='^', label='Txl') + +plt.xlabel('Context Length') +plt.ylabel('Latency (ms)') +plt.title('GPT2-Style End-to-End Inference Latency Comparison') + +# 把 0,1,2 映射成 256,512,1024 +plt.xticks(x, length_labels) + +plt.grid(True, linestyle='--', linewidth=0.5) +plt.legend() +plt.tight_layout() +plt.show() + From 603b811a17c3259a51974a4cab270e6d54156119 Mon Sep 17 00:00:00 2001 From: Qi_Zi <1156915330@qq.com> Date: Tue, 9 Dec 2025 10:20:11 +0800 Subject: [PATCH 6/6] update torch performance --- docker/end_to_end/bench.py | 53 ++++++++++++++++++++++++------- docker/end_to_end/draw_end2end.py | 2 +- docker/end_to_end/gpt.py | 12 +++++-- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/docker/end_to_end/bench.py b/docker/end_to_end/bench.py index 7e4d707..0616399 100644 --- a/docker/end_to_end/bench.py +++ b/docker/end_to_end/bench.py @@ -2,6 +2,7 @@ import torch import triton +import time from transformers import AutoTokenizer from gpt import convert_hf_and_load_model, load_model_only @@ -78,6 +79,38 @@ def fn(): fn = lambda: txl_model(inputs["input_ids"]) return triton.testing.do_bench(fn, warmup=warmup, rep=rep) +def run_benchmark_all(warmup=25, rep=100, mixed_precison=True, max_length=1024): + assert torch.cuda.is_available() + device = "cuda" + model_id = "gpt2" + model, hf_model, txl_model = convert_hf_and_load_model(model_id, 
device) + if mixed_precison: + model.to(torch.float16) + txl_model.to(torch.float16) + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer([STRING*1000] * 32, return_tensors="pt", max_length=max_length, truncation=True) + inputs = {k: v.to(device) for k, v in inputs.items()} + print("input_ids shape:", inputs["input_ids"].shape) + with torch.no_grad(): + def fn(): + if mixed_precison: + with torch.autocast(device_type="cuda", dtype=torch.float16): + return hf_model(**inputs).last_hidden_state + else: + return hf_model(**inputs).last_hidden_state + torch_time = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + torch.cuda.synchronize() + time.sleep(1) + fn = lambda: model(inputs["input_ids"]) + triton_time = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + torch.cuda.synchronize() + time.sleep(1) + fn = lambda: txl_model(inputs["input_ids"]) + txl_time = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + torch.cuda.synchronize() + time.sleep(1) + return torch_time, triton_time, txl_time + def validate(mixed_precison=True, max_length=1024): assert torch.cuda.is_available() device = "cuda" @@ -113,26 +146,24 @@ def validate(mixed_precison=True, max_length=1024): # triton: batch_size = 2048 && t = 3153.70 print("start") -validate(mixed_precison=True, max_length=1024) +# validate(mixed_precison=True, max_length=1024) # print("triton:", run_benchmark("triton")) # print("txl:", run_benchmark("txl")) # print("torch:", run_benchmark("torch")) # print("triton long:", run_benchmark("triton", max_length=4096, need_gpt=False)) # print("txl long:", run_benchmark("txl", max_length=4096, need_gpt=False)) -for i in range(8, 11): +for i in range(8, 13): length = 2**i - torch_time = run_benchmark("torch", max_length=length) - triton_time = run_benchmark("triton", max_length=length) - txl_time = run_benchmark("txl", max_length=length) + torch_time , triton_time, txl_time = run_benchmark_all(max_length=length) speedup_triton = torch_time / triton_time 
speedup_txl = torch_time / txl_time print(f"length: {length}, torch: {torch_time:.2f} ms, triton: {triton_time:.2f} ms, txl: {txl_time:.2f} ms, speedup_triton: {speedup_triton:.2f}x, speedup_txl: {speedup_txl:.2f}x") -for i in range(11, 13): - length = 2**i - triton_time = run_benchmark("triton", max_length=length, need_gpt=False) - txl_time = run_benchmark("txl", max_length=length, need_gpt=False) - speedup_txl = triton_time / txl_time - print(f"length: {length}, triton: {triton_time:.2f} ms, txl: {txl_time:.2f} ms, speedup_txl: {speedup_txl:.2f}x") +# for i in range(11, 13): +# length = 2**i +# triton_time = run_benchmark("triton", max_length=length, need_gpt=False) +# txl_time = run_benchmark("txl", max_length=length, need_gpt=False) +# speedup_txl = triton_time / txl_time +# print(f"length: {length}, triton: {triton_time:.2f} ms, txl: {txl_time:.2f} ms, speedup_txl: {speedup_txl:.2f}x") # print("txl test:", run_benchmark("txl", max_length=1024, need_gpt=False)) # print("triton test:", run_benchmark("triton", max_length=1024, need_gpt=False)) # OLD SUMMARY diff --git a/docker/end_to_end/draw_end2end.py b/docker/end_to_end/draw_end2end.py index 102ea0c..f9f6eca 100644 --- a/docker/end_to_end/draw_end2end.py +++ b/docker/end_to_end/draw_end2end.py @@ -4,7 +4,7 @@ length_labels = [256, 512, 1024, 2048, 4096] # 用来当 x 轴刻度标签 x = list(range(len(length_labels))) # 实际用于绘图的 x 坐标:0,1,2 -torch_times = [14.92, 28.43, 55.59] +torch_times = [14.92, 28.43, 55.59, 113.88, 240.12] triton_times = [5.93, 13.09, 32.74, 76.94, 202.10] txl_times = [8.28, 10.53, 22.13, 45.38, 102.71] diff --git a/docker/end_to_end/gpt.py b/docker/end_to_end/gpt.py index 90e2dda..9426f07 100644 --- a/docker/end_to_end/gpt.py +++ b/docker/end_to_end/gpt.py @@ -6,7 +6,7 @@ import torch.nn as nn from tqdm.auto import tqdm from transformers import AutoTokenizer -from transformers import GPT2Model as HFGPT2 +from transformers import GPT2Model, GPT2Config from kernels import (flash_attention_v1, 
fused_embeddings, fused_ffn, fused_layer_norm, matmul_and_split_qkv) @@ -339,7 +339,15 @@ def convert_huggingface_to_triton(hf_sd, hf_config): def convert_hf_and_load_model(model_id, device): - hf_model = HFGPT2.from_pretrained(model_id) + hf_config = GPT2Config( + n_positions=16384, + n_ctx=16384, + n_layer=12, + n_head=12, + n_embd=768, + vocab_size=50304, + ) + hf_model = GPT2Model(hf_config) state_dict, config = convert_huggingface_to_triton( hf_model.state_dict(), hf_model.config )