
Commit c8ab988

[BugFix] Fix DBO assert assert B_block_table == B_q (#29933)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 48a5fff commit c8ab988

File tree: 5 files changed, +83 -63 lines


tests/v1/attention/test_attention_splitting.py
Lines changed: 9 additions & 3 deletions

@@ -13,7 +13,7 @@
     split_attn_metadata,
     split_decodes_and_prefills,
 )
-from vllm.v1.worker.ubatch_utils import create_ubatch_slices
+from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
 
 
 @pytest.fixture
@@ -294,8 +294,14 @@ def test_prefill_split_across_ubatches(
     qsl_np = common.query_start_loc_cpu.numpy()
     num_tokens = common.num_actual_tokens
 
-    ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point)
-    assert len(ubatch_slices) == 2
+    ubatch_slices, _ = maybe_create_ubatch_slices(
+        True,
+        num_scheduled_tokens,
+        num_tokens,
+        batch_spec.batch_size,
+        split_point=split_point,
+    )
+    assert ubatch_slices is not None and len(ubatch_slices) == 2
 
     first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
     second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
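The behavioural difference the test now has to account for: the new helper gates on `should_ubatch` and returns `(None, None)` when ubatching is off, so callers must null-check before indexing. A toy stand-in of that gate (not the vLLM helper, which operates on `UBatchSlice` objects; names here are illustrative only):

# Toy stand-in for the gating behaviour of maybe_create_ubatch_slices.
def maybe_split(should_ubatch: bool, items: list[int], split_point: int):
    if not should_ubatch:
        # Ubatching disabled: nothing to split, mirror the (None, None) return.
        return None, None
    slices = [items[:split_point], items[split_point:]]
    # The real helper also returns a second, padded variant of the slices.
    return slices, slices

slices, _ = maybe_split(True, list(range(8)), split_point=3)
assert slices is not None and len(slices) == 2
assert maybe_split(False, list(range(8)), split_point=3) == (None, None)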

vllm/v1/spec_decode/eagle.py
Lines changed: 2 additions & 2 deletions

@@ -1258,7 +1258,7 @@ def _pad_batch_across_dp(
         num_tokens_padded: int,
     ) -> tuple[int, torch.Tensor]:
         # TODO(Flechman): support DBO ubatching
-        ubatch_slices, num_toks_across_dp = coordinate_batch_across_dp(
+        should_ubatch, num_toks_across_dp = coordinate_batch_across_dp(
             num_tokens_unpadded=num_tokens_unpadded,
             parallel_config=self.vllm_config.parallel_config,
             allow_microbatching=False,
@@ -1267,7 +1267,7 @@ def _pad_batch_across_dp(
             uniform_decode=None,
             num_scheduled_tokens_per_request=None,
         )
-        assert ubatch_slices is None, "DBO ubatching not implemented for EAGLE"
+        assert not should_ubatch, "DBO ubatching not implemented for EAGLE"
 
         num_tokens_dp_padded = num_tokens_padded
         if num_toks_across_dp is not None:

vllm/v1/worker/dp_utils.py
Lines changed: 4 additions & 39 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -9,10 +10,7 @@
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.v1.worker.ubatch_utils import (
-    UBatchSlice,
-    UBatchSlices,
     check_ubatch_thresholds,
-    create_ubatch_slices,
     is_second_ubatch_empty,
 )
 
@@ -91,20 +89,6 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
     return num_tokens_across_dp.cpu()
 
 
-# This just pads the second ubatch slice out to the total number of tokens
-# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
-def _pad_out_ubatch_slice(
-    ubatch_slices: UBatchSlices, num_total_tokens: int
-) -> UBatchSlices:
-    padded_second_token_slice = slice(
-        ubatch_slices[1].token_slice.start, num_total_tokens
-    )
-    ubatch_slices[1] = UBatchSlice(
-        ubatch_slices[1].request_slice, padded_second_token_slice
-    )
-    return ubatch_slices
-
-
 def _synchronize_dp_ranks(
     num_tokens_unpadded: int,
     num_tokens_padded: int,
@@ -175,7 +159,7 @@ def coordinate_batch_across_dp(
     num_tokens_padded: int | None = None,
     uniform_decode: bool | None = None,
     num_scheduled_tokens_per_request: np.ndarray | None = None,
-) -> tuple[UBatchSlices | None, torch.Tensor | None]:
+) -> tuple[bool, torch.Tensor | None]:
     """
     Coordinates amongst all DP ranks to determine if and how the full batch
     should be split into microbatches.
@@ -204,7 +188,7 @@ def coordinate_batch_across_dp(
     """
     if parallel_config.data_parallel_size == 1:
         # Early exit.
-        return None, None
+        return False, None
 
     # If the caller has explicitly enabled microbatching.
     should_attempt_ubatching = False
@@ -228,23 +212,4 @@ def coordinate_batch_across_dp(
         parallel_config,
     )
 
-    # Don't microbatch unless every other DP worker is also microbatching
-    if not should_ubatch:
-        return (None, num_tokens_after_padding)
-
-    # This doesn't actually pad the ubatch slices. It just initializes the
-    # split point to the padded value so that padding can be applied
-    # to the second ubatch in pad_out_ubatch_slice after attention
-    # metadata creation
-    assert num_tokens_after_padding is not None
-    num_tokens_padded = int(num_tokens_after_padding[0].item())
-    token_split_point = int(num_tokens_padded) // 2
-
-    assert num_scheduled_tokens_per_request is not None
-    ubatch_slices = create_ubatch_slices(
-        num_scheduled_tokens_per_request, token_split_point
-    )
-    ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
-    assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded
-
-    return (ubatch_slices, num_tokens_after_padding)
+    return (should_ubatch, num_tokens_after_padding)
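After this change `coordinate_batch_across_dp` only answers the cross-rank question: did every DP rank agree to microbatch, and what are the padded token counts. Slice construction moves to `ubatch_utils.maybe_create_ubatch_slices` and happens later on the caller's side. A toy illustration of the agreement rule the returned bool carries (not the vLLM implementation, which reaches the same decision with a collective across the DP group):

# Toy illustration of the "every DP worker must agree" rule behind should_ubatch.
def decide_should_ubatch(per_rank_wants_ubatch: list[bool]) -> bool:
    # Don't microbatch unless every other DP worker is also microbatching.
    return all(per_rank_wants_ubatch)

assert decide_should_ubatch([True, True, True, True]) is True
assert decide_should_ubatch([True, False, True, True]) is False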

vllm/v1/worker/gpu_model_runner.py
Lines changed: 29 additions & 16 deletions

@@ -153,6 +153,7 @@
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
+    maybe_create_ubatch_slices,
 )
 from vllm.v1.worker.utils import is_residual_scattered_for_sp
 
@@ -2743,7 +2744,7 @@ def _determine_batch_execution_and_padding(
     ) -> tuple[
         CUDAGraphMode,
         BatchDescriptor,
-        UBatchSlices | None,
+        bool,
         torch.Tensor | None,
         CUDAGraphStat | None,
     ]:
@@ -2779,7 +2780,7 @@ def _determine_batch_execution_and_padding(
 
         # Extra coordination when running data-parallel since we need to coordinate
        # across ranks
-        ubatch_slices, num_tokens_across_dp = None, None
+        should_ubatch, num_tokens_across_dp = False, None
         if self.vllm_config.parallel_config.data_parallel_size > 1:
             # Disable DP padding when running eager to avoid excessive padding when
             # running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
@@ -2789,8 +2790,8 @@ def _determine_batch_execution_and_padding(
                 self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
             )
 
-            ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-                num_tokens_unpadded=num_tokens_padded,
+            should_ubatch, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens,
                 parallel_config=self.parallel_config,
                 allow_microbatching=allow_microbatching,
                 allow_dp_padding=allow_dp_padding,
@@ -2822,7 +2823,7 @@ def _determine_batch_execution_and_padding(
         return (
             cudagraph_mode,
             batch_descriptor,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         )
@@ -2921,7 +2922,7 @@ def execute_model(
         (
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
             cudagraph_stats,
         ) = self._determine_batch_execution_and_padding(
@@ -2934,29 +2935,37 @@ def execute_model(
 
         logger.debug(
             "Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
-            "ubatch_slices: %s, num_tokens_across_dp: %s",
+            "should_ubatch: %s, num_tokens_across_dp: %s",
             cudagraph_mode,
             batch_desc,
-            ubatch_slices,
+            should_ubatch,
             num_tokens_across_dp,
         )
 
         num_tokens_padded = batch_desc.num_tokens
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch,
+            num_scheduled_tokens_np,
+            num_tokens_padded,
+            num_reqs_padded,
+        )
 
-        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
         pad_attn = cudagraph_mode == CUDAGraphMode.FULL
 
+        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
+        ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
+
         (attn_metadata, spec_decode_common_attn_metadata) = (
             self._build_attention_metadata(
                 num_tokens=num_tokens_unpadded,
                 num_tokens_padded=num_tokens_padded if pad_attn else None,
                 num_reqs=num_reqs,
                 num_reqs_padded=num_reqs_padded if pad_attn else None,
                 max_query_len=max_num_scheduled_tokens,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_attn,
                 logits_indices=logits_indices,
                 use_spec_decode=use_spec_decode,
                 num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
@@ -2993,7 +3002,7 @@ def execute_model(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
             record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
@@ -3945,7 +3954,7 @@ def _dummy_run(
 
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
 
-        _cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
+        _cudagraph_mode, batch_desc, should_ubatch, num_tokens_across_dp, _ = (
            self._determine_batch_execution_and_padding(
                num_tokens=num_tokens_unpadded,
                num_reqs=num_reqs,
@@ -3979,6 +3988,9 @@ def _dummy_run(
         num_reqs_padded = (
             batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
         )
+        ubatch_slices, ubatch_slices_padded = maybe_create_ubatch_slices(
+            should_ubatch, num_scheduled_tokens, num_tokens_padded, num_reqs_padded
+        )
 
         attn_metadata: PerLayerAttnMetadata | None = None
 
@@ -4000,11 +4012,12 @@ def _dummy_run(
             self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
             self.query_start_loc.copy_to_gpu()
 
+            pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
             attn_metadata, _ = self._build_attention_metadata(
                 num_tokens=num_tokens_unpadded,
                 num_reqs=num_reqs_padded,
                 max_query_len=max_query_len,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
                 for_cudagraph_capture=is_graph_capturing,
             )
 
@@ -4056,11 +4069,11 @@ def _dummy_run(
             num_tokens_padded, None, False
         )
 
-        if ubatch_slices is not None:
+        if ubatch_slices_padded is not None:
             # Adjust values to reflect a single ubatch.
             # TODO(sage,lucas): this is cruft that should be addressed in
             # the padding refactor.
-            num_tokens_padded = ubatch_slices[0].num_tokens
+            num_tokens_padded = ubatch_slices_padded[0].num_tokens
             if num_tokens_across_dp is not None:
                 num_tokens_across_dp[:] = num_tokens_padded
 
@@ -4073,7 +4086,7 @@ def _dummy_run(
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_desc,
-                ubatch_slices=ubatch_slices,
+                ubatch_slices=ubatch_slices_padded,
             ),
         ):
            outputs = self.model(
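The net effect in the runner: `maybe_create_ubatch_slices` now yields both an unpadded and a padded slice set, and the two are routed differently. Attention metadata is built from the padded slices only when attention itself is padded (FULL cudagraph mode), otherwise from the unpadded split; the forward context and cudagraph dispatch always receive the padded slices. A minimal sketch of that selection, with a stand-in enum (members and values assumed here, not vLLM's definitions):

from enum import Enum

class CUDAGraphMode(Enum):  # stand-in for vLLM's enum; members assumed
    NONE = 0
    PIECEWISE = 1
    FULL = 2

def slices_for_attention(mode: CUDAGraphMode, unpadded, padded):
    # Attention is only padded when capturing/replaying full graphs.
    pad_attn = mode == CUDAGraphMode.FULL
    return padded if pad_attn else unpadded

assert slices_for_attention(CUDAGraphMode.PIECEWISE, "unpadded", "padded") == "unpadded"
assert slices_for_attention(CUDAGraphMode.FULL, "unpadded", "padded") == "padded"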

vllm/v1/worker/ubatch_utils.py
Lines changed: 39 additions & 3 deletions

@@ -42,9 +42,37 @@ def check_ubatch_thresholds(
     return num_tokens >= config.dbo_prefill_token_threshold
 
 
-def create_ubatch_slices(
-    num_scheduled_tokens: np.ndarray, split_point: int
+# This just pads the second ubatch slice out to the total number of tokens
+# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
+def _pad_out_ubatch_slices(
+    ubatch_slices: UBatchSlices, num_total_tokens: int, num_reqs_padded: int
 ) -> UBatchSlices:
+    # TODO(lucas): handle empty second ubatch
+    padded_second_request_slice = slice(
+        ubatch_slices[1].request_slice.start, num_reqs_padded
+    )
+    padded_second_token_slice = slice(
+        ubatch_slices[1].token_slice.start, num_total_tokens
+    )
+    return [
+        ubatch_slices[0],
+        UBatchSlice(padded_second_request_slice, padded_second_token_slice),
+    ]
+
+
+def maybe_create_ubatch_slices(
+    should_ubatch: bool,
+    num_scheduled_tokens: np.ndarray,
+    num_tokens_padded: int,
+    num_reqs_padded: int,
+    split_point: int | None = None,
+) -> tuple[UBatchSlices | None, UBatchSlices | None]:
+    if not should_ubatch:
+        return None, None
+
+    if split_point is None:
+        split_point = int(num_tokens_padded) // 2
+
     # TODO(lucas): Refactor the gpu_model_runner.py so we can pass
     # in cu_num_tokens directly (i.e. query_start_loc)
     cu_num_tokens = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
@@ -67,7 +95,15 @@ def create_ubatch_slices(
     )
     second_ubatch_req_slice = slice(second_ubatch_req_start, len(cu_num_tokens) - 1)
 
-    return [
+    ubatch_slices = [
         UBatchSlice(first_ubatch_req_slice, first_ubatch_token_slice),
         UBatchSlice(second_ubatch_req_slice, second_ubatch_token_slice),
     ]
+
+    ubatch_slices_padded = _pad_out_ubatch_slices(
+        ubatch_slices, num_tokens_padded, num_reqs_padded
+    )
+
+    assert sum(s.num_tokens for s in ubatch_slices_padded) == num_tokens_padded
+
+    return ubatch_slices, ubatch_slices_padded
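To make the split-then-pad behaviour concrete, a rough standalone walk-through with made-up numbers follows (plain Python slices instead of vLLM's `UBatchSlice`; the real request-boundary handling may differ in detail):

import numpy as np

# Hypothetical batch: 3 requests scheduled with 5, 7 and 4 tokens (16 total);
# DP padding rounds this rank up to 20 tokens over 4 request slots.
num_scheduled_tokens = np.array([5, 7, 4], dtype=np.int32)
num_tokens_padded, num_reqs_padded = 20, 4
split_point = num_tokens_padded // 2  # 10, the default when no split_point is passed

cu = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int64)
np.cumsum(num_scheduled_tokens, out=cu[1:])  # cu == [0, 5, 12, 16]

# Unpadded slices: first ubatch takes tokens [0, 10), second takes [10, 16).
first_tokens = slice(0, split_point)
second_tokens = slice(split_point, int(cu[-1]))
# The request containing the split point (request 1, tokens 5..11) is shared,
# so it appears at the end of the first ubatch and the start of the second.
split_req = int(np.searchsorted(cu, split_point, side="right") - 1)  # -> 1
first_reqs = slice(0, split_req + 1)                                 # requests 0..1
second_reqs = slice(split_req, len(num_scheduled_tokens))            # requests 1..2

# Padded slices: only the second ubatch is stretched out to the padded totals.
second_tokens_padded = slice(second_tokens.start, num_tokens_padded)  # tokens 10..19
second_reqs_padded = slice(second_reqs.start, num_reqs_padded)        # requests 1..3

# Token coverage of the padded pair equals the padded batch size, matching the
# assertion added in maybe_create_ubatch_slices.
assert (first_tokens.stop - first_tokens.start) + (
    second_tokens_padded.stop - second_tokens_padded.start
) == num_tokens_padded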
