Commit 02ccc8d

Refactor CUDA graph padding logic

- Move padding calculation before CUDA graph dispatch
- Update dispatch() to take uniform_decode directly instead of computing it
- Remove max_num_scheduled_tokens parameter from dispatch()
- Update BatchDescriptor to use 'uniform' field consistently
- Fix _prepare_inputs to handle new padding flow
- Update attention backends to work with new padding approach
- Add documentation for BatchDescriptor fields

Co-authored-by: ayushsatyam146 <ayushsatyam146@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 11857a0 commit 02ccc8d
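
The bullets above describe the reordered flow: compute padding first, then build the batch descriptor and dispatch. A minimal sketch of that flow, assuming a hypothetical `pad_num_tokens` helper and a `dispatch(desc)` method on the dispatcher; these names and signatures are inferred from the commit message, not taken from the vLLM source, and the field comments are guesses from the field names.

```python
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    """Key for CUDA graph lookup (fields per the docs diff below)."""
    num_tokens: int         # padded token count for the batch
    num_reqs: int           # number of requests in the batch
    uniform: bool = False   # all requests share the same query length
    has_lora: bool = False  # any request carries LoRA adapters


def prepare_and_dispatch(dispatcher, num_scheduled_tokens: int,
                         num_reqs: int, uniform_decode: bool):
    """Hypothetical post-refactor flow: pad *before* dispatch, and pass
    uniform_decode in directly instead of having dispatch() derive it
    from a max_num_scheduled_tokens argument."""
    # 1. Padding is calculated up front, so the descriptor already
    #    reflects the padded batch size.
    num_tokens = dispatcher.pad_num_tokens(num_scheduled_tokens)  # assumed helper
    # 2. Build the descriptor from the padded size.
    desc = BatchDescriptor(num_tokens=num_tokens, num_reqs=num_reqs,
                           uniform=uniform_decode)
    # 3. dispatch() now takes the descriptor only; there is no
    #    max_num_scheduled_tokens parameter.
    return dispatcher.dispatch(desc)
```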

File tree

11 files changed: +5794 −165 lines

docs/design/cuda_graphs.md

Lines changed: 5 additions & 3 deletions

@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and current

 ```python
 class BatchDescriptor(NamedTuple):
     num_tokens: int
-    uniform_decode: bool = False
+    num_reqs: int
+    uniform: bool = False
+    has_lora: bool = False
 ```

-where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
+where `num_tokens` can be the padded token length, and `uniform` indicates whether all the requests have the same query length. Many attention backends only support full CUDA Graphs when the batch is uniform; pure decode batches are uniform but may not have query length 1 (i.e. `num_tokens == num_reqs`). This occurs in the validation pass of spec-decode, where "decode" batches have a query length of `1+num_spec_tokens`.

-The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
+The goal of this structure is to uniquely identify a (padded) batch with the minimal possible items corresponding to a CUDA Graphs item.

 !!! note
     The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
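
To make the `uniform` semantics concrete, here are the three batch shapes the new text distinguishes, expressed as descriptors; the class mirrors the `BatchDescriptor` shown in the diff above, and the numeric values are illustrative only.

```python
from typing import NamedTuple

class BatchDescriptor(NamedTuple):  # as defined in the diff above
    num_tokens: int
    num_reqs: int
    uniform: bool = False
    has_lora: bool = False

# Pure decode: each of 8 requests contributes exactly 1 token,
# so num_tokens == num_reqs and the batch is uniform.
pure_decode = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)

# Spec-decode validation with num_spec_tokens = 3: each request
# contributes 1 + 3 = 4 tokens; still uniform, but num_tokens != num_reqs.
spec_validate = BatchDescriptor(num_tokens=32, num_reqs=8, uniform=True)

# Mixed prefill/decode batch: query lengths differ across requests,
# so the batch is not uniform.
mixed = BatchDescriptor(num_tokens=40, num_reqs=8, uniform=False)
```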
Lines changed: 160 additions & 0 deletions

@@ -0,0 +1,160 @@

{
  "results": {
    "gsm8k": {
      "alias": "gsm8k",
      "exact_match,strict-match": 0.756633813495072,
      "exact_match_stderr,strict-match": 0.011819940385701125,
      "exact_match,flexible-extract": 0.755117513267627,
      "exact_match_stderr,flexible-extract": 0.011844819027863667
    }
  },
  "group_subtasks": {
    "gsm8k": []
  },
  "configs": {
    "gsm8k": {
      "task": "gsm8k",
      "tag": [
        "math_word_problems"
      ],
      "dataset_path": "gsm8k",
      "dataset_name": "main",
      "training_split": "train",
      "test_split": "test",
      "fewshot_split": "train",
      "doc_to_text": "Question: {{question}}\nAnswer:",
      "doc_to_target": "{{answer}}",
      "unsafe_code": false,
      "description": "",
      "target_delimiter": " ",
      "fewshot_delimiter": "\n\n",
      "num_fewshot": 5,
      "metric_list": [
        {
          "metric": "exact_match",
          "aggregation": "mean",
          "higher_is_better": true,
          "ignore_case": true,
          "ignore_punctuation": false,
          "regexes_to_ignore": [
            ",",
            "\\$",
            "(?s).*#### ",
            "\\.$"
          ]
        }
      ],
      "output_type": "generate_until",
      "generation_kwargs": {
        "until": [
          "Question:",
          "</s>",
          "<|im_end|>"
        ],
        "do_sample": false,
        "temperature": 0.0
      },
      "repeats": 1,
      "filter_list": [
        {
          "name": "strict-match",
          "filter": [
            {
              "function": "regex",
              "regex_pattern": "#### (\\-?[0-9\\.\\,]+)"
            },
            {
              "function": "take_first"
            }
          ]
        },
        {
          "name": "flexible-extract",
          "filter": [
            {
              "function": "regex",
              "group_select": -1,
              "regex_pattern": "(-?[$0-9.,]{2,})|(-?[0-9]+)"
            },
            {
              "function": "take_first"
            }
          ]
        }
      ],
      "should_decontaminate": false,
      "metadata": {
        "version": 3.0,
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "base_url": "http://localhost:3333/v1/completions",
        "num_concurrent": 256
      }
    }
  },
  "versions": {
    "gsm8k": 3.0
  },
  "n-shot": {
    "gsm8k": 5
  },
  "higher_is_better": {
    "gsm8k": {
      "exact_match": true
    }
  },
  "n-samples": {
    "gsm8k": {
      "original": 1319,
      "effective": 1319
    }
  },
  "config": {
    "model": "local-completions",
    "model_args": "model=meta-llama/Meta-Llama-3-8B-Instruct,base_url=http://localhost:3333/v1/completions,num_concurrent=256",
    "batch_size": "auto",
    "batch_sizes": [],
    "device": null,
    "use_cache": null,
    "limit": null,
    "bootstrap_iters": 100000,
    "gen_kwargs": null,
    "random_seed": 0,
    "numpy_seed": 1234,
    "torch_seed": 1234,
    "fewshot_seed": 1234
  },
  "git_hash": "v0.11.0rc1-1437-g6160f1ce0",
  "date": 1762924461.0959744,
  "pretty_env_info": "PyTorch version: 2.9.0+cu128\nIs debug build: False\nCUDA used to build PyTorch: 12.8\nROCM used to build PyTorch: N/A\n\nOS: CentOS Stream 9 (x86_64)\nGCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-11)\nClang version: Could not collect\nCMake version: version 4.1.0\nLibc version: glibc-2.34\n\nPython version: 3.12.11 (main, Aug 14 2025, 00:00:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-11)] (64-bit runtime)\nPython platform: Linux-5.14.0-620.el9.x86_64-x86_64-with-glibc2.34\nIs CUDA available: True\nCUDA runtime version: 12.9.86\nCUDA_MODULE_LOADING set to: \nGPU models and configuration: \nGPU 0: NVIDIA H100 80GB HBM3\nGPU 1: NVIDIA H100 80GB HBM3\nGPU 2: NVIDIA H100 80GB HBM3\nGPU 3: NVIDIA H100 80GB HBM3\nGPU 4: NVIDIA H100 80GB HBM3\nGPU 5: NVIDIA H100 80GB HBM3\nGPU 6: NVIDIA H100 80GB HBM3\nGPU 7: NVIDIA H100 80GB HBM3\n\nNvidia driver version: 580.95.05\ncuDNN version: Could not collect\nIs XPU available: False\nHIP runtime version: N/A\nMIOpen runtime version: N/A\nIs XNNPACK available: True\n\nCPU:\nArchitecture: x86_64\nCPU op-mode(s): 32-bit, 64-bit\nAddress sizes: 46 bits physical, 57 bits virtual\nByte Order: Little Endian\nCPU(s): 160\nOn-line CPU(s) list: 0-159\nVendor ID: GenuineIntel\nModel name: Intel Xeon Processor (SapphireRapids)\nCPU family: 6\nModel: 143\nThread(s) per core: 2\nCore(s) per socket: 40\nSocket(s): 2\nStepping: 4\nBogoMIPS: 4200.00\nFlags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities\nVirtualization: VT-x\nHypervisor vendor: KVM\nVirtualization type: full\nL1d cache: 5 MiB (160 instances)\nL1i cache: 5 MiB (160 instances)\nL2 cache: 320 MiB (80 instances)\nL3 cache: 32 MiB (2 instances)\nNUMA node(s): 2\nNUMA node0 CPU(s): 0-79\nNUMA node1 CPU(s): 80-159\nVulnerability Gather data sampling: Not affected\nVulnerability Indirect target selection: Mitigation; Aligned branch/return thunks\nVulnerability Itlb multihit: Not affected\nVulnerability L1tf: Not affected\nVulnerability Mds: Not affected\nVulnerability Meltdown: Not affected\nVulnerability Mmio stale data: Unknown: No mitigations\nVulnerability Reg file data sampling: Not affected\nVulnerability Retbleed: Not affected\nVulnerability Spec rstack overflow: Not affected\nVulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl\nVulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization\nVulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop\nVulnerability Srbds: Not affected\nVulnerability Tsa: Not affected\nVulnerability Tsx async abort: Not affected\n\nVersions of relevant libraries:\n[pip3] Could not collect\n[conda] Could not collect",
  "transformers_version": "4.56.2",
  "lm_eval_version": "0.4.9.1",
  "upper_git_hash": null,
  "tokenizer_pad_token": [
    "<|eot_id|>",
    "128009"
  ],
  "tokenizer_eos_token": [
    "<|eot_id|>",
    "128009"
  ],
  "tokenizer_bos_token": [
    "<|begin_of_text|>",
    "128000"
  ],
  "eot_token_id": 128009,
  "max_length": 2047,
  "task_hashes": {
    "gsm8k": "2330f4ebfcccaf66a892922df2819cdb1f118e448d076d3f42bdde4177678ac7"
  },
  "model_source": "local-completions",
  "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
  "model_name_sanitized": "meta-llama__Meta-Llama-3-8B-Instruct",
  "system_instruction": null,
  "system_instruction_sha": null,
  "fewshot_as_multiturn": false,
  "chat_template": null,
  "chat_template_sha": null,
  "start_time": 3581101.19808223,
  "end_time": 3581147.328442393,
  "total_evaluation_time_seconds": "46.13036016281694"
}
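
The headline numbers in this file (GSM8K exact match ≈ 0.757 strict, ≈ 0.755 flexible) can be read back with a few lines of stdlib Python; the filename below is a placeholder for wherever this results JSON is saved.

```python
import json

# Placeholder path; point this at the lm_eval results file above.
with open("gsm8k_results.json") as f:
    gsm8k = json.load(f)["results"]["gsm8k"]

# Print each metric alongside its standard error.
for metric in ("exact_match,strict-match", "exact_match,flexible-extract"):
    stderr = gsm8k[metric.replace("exact_match", "exact_match_stderr", 1)]
    print(f"{metric}: {gsm8k[metric]:.4f} ± {stderr:.4f}")
```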
