2 changes: 2 additions & 0 deletions .gitignore
@@ -52,5 +52,7 @@ tilelang_optimization_analysis.md
boundary_check_comparison.md
GITHUB_ISSUE.md
Tilelang-failed_test_cases/
# Benchmark results
benchmark_results/
# Cursor IDE files
.cursor/
7 changes: 7 additions & 0 deletions diffulex/engine/tp_worker.py
@@ -67,6 +67,13 @@ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
return seq.seq_id

def step(self):
# Clear step-local activation quant cache (W8A8/W4A8, etc.) so we only reuse within a single step.
try:
from diffulex.utils.quantization.context import clear_act_quant_cache
clear_act_quant_cache()
except Exception:
# Quantization context may not be initialized in some paths; ignore.
pass
seqs, is_prefill = self.scheduler.schedule()
sample_output = self.model_runner.call("run", seqs, is_prefill)
n_diff_steps = self.scheduler.postprocess(seqs, sample_output)
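For context: the cache being cleared at the top of step() is the one added in diffulex/utils/quantization/context.py further down, and it is meant to be filled and consulted by the quantized-linear paths within the same step. A minimal sketch of that consumer pattern follows; the get/set helpers are the ones this PR adds, while the per-row INT8 quantizer is a hypothetical placeholder, not this repo's actual kernel path.

import torch
from diffulex.utils.quantization.context import (
    get_cached_act_quant,
    set_cached_act_quant,
)

def _quantize_per_row_int8(x: torch.Tensor):
    # Illustrative per-row symmetric quantization (placeholder, not the real kernel).
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return (x / scales).round().clamp(-127, 127).to(torch.int8), scales

def w8a8_activation_quant(x: torch.Tensor):
    # Reuse the quantized activation if this exact tensor was already quantized this step.
    cached = get_cached_act_quant(x)
    if cached is not None:
        return cached
    x_q, x_scales = _quantize_per_row_int8(x)
    set_cached_act_quant(x, x_q, x_scales)
    return x_q, x_scales

# tp_worker.step() calls clear_act_quant_cache() first, so entries never leak across steps.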
36 changes: 33 additions & 3 deletions diffulex/strategy/d2f/engine/kvcache_manager.py
@@ -14,17 +14,38 @@ class D2FKVCacheManager(KVCacheManagerBase):
def __init__(self, config: Config):
super().__init__(config)

def _required_kv_blocks(self, seq: "D2FSequence") -> int:
"""How many KV-cache blocks this sequence needs *now* for cached+to-cache tokens.

NOTE: In diffusion decoding, a single decode step may move multiple tokens into
"to_cache", which can cross multiple KV blocks. So we must ensure block_table
is large enough for all cached_or_caching tokens, not just append one block.
"""
n = seq.cached_or_caching_num_tokens
if n <= 0:
return 0
# Need enough blocks to cover token indices [0, n-1].
return (n + self.block_size - 1) // self.block_size

def can_append(self, seq: "D2FSequence") -> bool:
return len(self.free_block_ids) >= (seq.cached_or_caching_num_tokens % self.block_size == 1)
# We may need to allocate multiple blocks in one step (cached_or_caching can jump).
required = self._required_kv_blocks(seq)
missing = max(0, required - len(seq.block_table))
return len(self.free_block_ids) >= missing

def may_append(self, seq: "D2FSequence") -> None:
if seq.cached_or_caching_num_tokens == 0:
return
block_table = seq.block_table
if not block_table:
# Defensive: allocate() should have populated it for prefill/prompt, but don't crash here.
return
last_block = self.blocks[block_table[-1]]
if seq.cached_or_caching_num_tokens // self.block_size == len(seq.block_table):

required = self._required_kv_blocks(seq)
# Allocate enough KV blocks to cover all cached_or_caching tokens.
while len(block_table) < required:
last_block = self.blocks[block_table[-1]]
# Preserve the existing "finalize previous block hash" behavior before moving on.
if last_block.hash == -1:
prev_end_token = seq.cached_or_caching_num_tokens - seq.caching_num_tokens - 1
prev_block_idx = prev_end_token // self.block_size
@@ -34,6 +55,15 @@ def may_append(self, seq: "D2FSequence") -> None:
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id

if not self.free_block_ids:
raise RuntimeError(
"D2FKVCacheManager: insufficient free KV cache blocks to append: "
f"required={required}, current_len={len(block_table)}, "
f"cached_or_caching_num_tokens={seq.cached_or_caching_num_tokens}, "
f"block_size={self.block_size}."
)

block_id = self.free_block_ids[0]
self._allocate_block(block_id)
block_table.append(block_id)
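A short worked example of the new sizing logic (a block_size of 256 is assumed purely for illustration):

# block_size = 256 (assumed), cached_or_caching_num_tokens = 513
# _required_kv_blocks: (513 + 255) // 256 = 3 blocks cover token indices [0, 512]
# len(seq.block_table) = 1  ->  missing = max(0, 3 - 1) = 2
# can_append() is True only if len(free_block_ids) >= 2, and may_append() then loops
# until len(block_table) == 3, finalizing each full block's hash along the way,
# instead of appending at most one block per step as the old code did.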
28 changes: 21 additions & 7 deletions diffulex/strategy/d2f/engine/model_runner.py
@@ -202,6 +202,21 @@ def get_step(diff_blk, begin_idx):
cur_diffusion_block_start = 0
cur_diffusion_block_end = step
start_idx += step
# IMPORTANT:
# We must have a KV-cache block allocated for this mem_block_idx.
# If not, this is almost always due to insufficient KV cache blocks
# (e.g. higher model/weight memory footprint leaves too few blocks).
if mem_block_idx >= len(seq.block_table):
raise RuntimeError(
"KV cache block allocation is insufficient during decode: "
f"mem_block_idx={mem_block_idx} requires block_table length >= {mem_block_idx + 1}, "
f"but got len(block_table)={len(seq.block_table)} (seq.num_blocks={seq.num_blocks}). "
"This usually means GPU memory utilization is too low to allocate enough KV cache "
f"blocks for this run (num_kvcache_blocks={getattr(self.config, 'num_kvcache_blocks', None)}, "
f"gpu_memory_utilization={getattr(self.config, 'gpu_memory_utilization', None)}). "
"Try increasing gpu_memory_utilization, reducing max_model_len/max_tokens/max_num_seqs, "
"or using a lower-memory weight quantization (e.g. int4)."
)
mem_block_start = (
seq.block_table[mem_block_idx] * self.block_size
+ context_len % seq.block_size
@@ -246,13 +261,12 @@ def get_step(diff_blk, begin_idx):
context_lens_tensor = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
block_tables = self.prepare_block_tables(seqs)
# NOTE:
# - d2f decode currently uses "varlen" mode by default.
# - When kv_cache_dtype is FP8, "varlen" decode falls back to Python dequantization via
# `load_kvcache`, which can materialize large intermediate tensors and often makes FP8
# KV *slower* than BF16.
# - Prefer TileLang's BF16Q+FP8KV decode kernel path by switching to "static" mode when
# FP8 KV is enabled.
# - Allow manual override via config.decode_mode if specified
# - d2f decode supports "varlen" and "static" modes (see config.decode_mode).
# - For FP8 KV, the (varlen/distinct-layout) path uses `load_kvcache` which is expected to
# handle FP8 dequantization / scale application inside the fused operator (no Python-level dequant).
# - Performance can still differ between modes/kernels; when FP8 KV is enabled, prefer the
# best-supported kernel path on your stack (often "static"/unified-layout) and validate with profiling.
# - Allow manual override via config.decode_mode if specified.
decode_mode = self._get_decode_mode()
set_d2f_attn_metadata(
False,
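Restating the guarded invariant with concrete (illustrative) numbers:

# Indexing seq.block_table[mem_block_idx] needs len(seq.block_table) >= mem_block_idx + 1.
# For example, mem_block_idx = 5 with len(block_table) = 4 previously failed with a bare
# IndexError at seq.block_table[mem_block_idx]; it now raises the RuntimeError above,
# which names the config knobs (gpu_memory_utilization, max_model_len, max_num_seqs,
# weight quantization) that determine how many KV-cache blocks get allocated.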
45 changes: 45 additions & 0 deletions diffulex/utils/quantization/context.py
@@ -28,6 +28,9 @@ class QuantizationContext:

def __init__(self):
self._strategies: Dict[str, QuantizationStrategy] = {}
# Step-local cache for activation quantization (e.g., W8A8 per-row quant).
# Keyed by tensor identity+layout to allow reuse within a single forward/step.
self._act_quant_cache: Dict[tuple, tuple] = {}

@classmethod
def current(cls) -> 'QuantizationContext':
@@ -86,6 +89,33 @@ def get_linear_strategy(self, kind: str) -> Optional[LinearQuantizationStrategy]
def clear(self):
"""Clear all strategies."""
self._strategies.clear()
self._act_quant_cache.clear()

# ---- Activation quantization cache helpers (step-local) ----
def clear_act_quant_cache(self) -> None:
self._act_quant_cache.clear()

def _act_quant_cache_key(self, x) -> tuple:
# Include version to avoid reusing after in-place mutation.
# data_ptr() is stable for the tensor storage; combine with shape/stride/dtype/device.
try:
version = getattr(x, "_version", None)
except Exception:
version = None
return (
int(x.data_ptr()),
tuple(x.shape),
tuple(x.stride()),
str(x.dtype),
str(x.device),
int(version) if version is not None else -1,
)

def get_cached_act_quant(self, x):
return self._act_quant_cache.get(self._act_quant_cache_key(x))

def set_cached_act_quant(self, x, x_q, x_scales) -> None:
self._act_quant_cache[self._act_quant_cache_key(x)] = (x_q, x_scales)

def __enter__(self):
return self
@@ -136,3 +166,18 @@ def get_linear_strategy(kind: str) -> Optional[LinearQuantizationStrategy]:
ctx = QuantizationContext.current()
return ctx.get_linear_strategy(kind)


def clear_act_quant_cache() -> None:
"""Clear step-local activation quant cache for the current thread."""
QuantizationContext.current().clear_act_quant_cache()


def get_cached_act_quant(x):
"""Get cached (x_q, x_scales) for activation quantization, or None."""
return QuantizationContext.current().get_cached_act_quant(x)


def set_cached_act_quant(x, x_q, x_scales) -> None:
"""Set cached (x_q, x_scales) for activation quantization."""
QuantizationContext.current().set_cached_act_quant(x, x_q, x_scales)
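A small sketch of how the cache key behaves under in-place mutation. It relies on PyTorch's Tensor._version counter, which in-place ops increment; the quantized values passed to set_cached_act_quant below are placeholders.

import torch
from diffulex.utils.quantization.context import (
    get_cached_act_quant,
    set_cached_act_quant,
)

x = torch.randn(4, 8)
set_cached_act_quant(x, x.to(torch.int8), x.abs().amax(dim=-1))  # placeholder (x_q, x_scales)
assert get_cached_act_quant(x) is not None   # same storage, same version -> cache hit
x.add_(1.0)                                  # in-place update bumps x._version
assert get_cached_act_quant(x) is None       # key changed -> stale entry is not reused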

8 changes: 6 additions & 2 deletions diffulex/utils/quantization/registry.py
@@ -86,11 +86,15 @@ def _normalize_linear_dtype(dtype: str) -> str:
"gptq": "gptq",
"awq": "awq",
"gptq_awq": "gptq_awq",
# vLLM-style fused W8A16 path (Diffulex vendored): user-facing alias "marlin"
# Normalized key is "marlin_int8" to avoid conflating with other quant methods.
"marlin": "marlin_int8",
"marlin_int8": "marlin_int8",
}
if s not in aliases:
raise ValueError(
f"Unsupported linear quant dtype={dtype!r}. "
"Supported: bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/awq"
"Supported: bf16/int8/int4/fp8/fp8_e4m3/fp8_e5m2/gptq/awq/marlin"
)
return aliases[s]

@@ -146,6 +150,6 @@ def create_linear_strategy(*, weight_dtype: str, act_dtype: str) -> LinearQuanti
def registered_linear_dtypes() -> list[str]:
"""Return the normalized dtype/method names accepted by `_normalize_linear_dtype`."""
# Keep this list stable for CLI/help messages.
return ["bf16", "int8", "int4", "fp8_e4m3", "fp8_e5m2", "gptq", "awq", "gptq_awq"]
return ["bf16", "int8", "int4", "fp8_e4m3", "fp8_e5m2", "gptq", "awq", "gptq_awq", "marlin_int8"]


2 changes: 2 additions & 0 deletions diffulex/utils/quantization/strategies/__init__.py
@@ -8,6 +8,7 @@
from diffulex.utils.quantization.strategies.linear_bf16 import LinearBF16Strategy
from diffulex.utils.quantization.strategies.linear_stub import LinearStubStrategy
from diffulex.utils.quantization.strategies.linear_int8_w8a16 import LinearInt8W8A16Strategy # noqa: F401
from diffulex.utils.quantization.strategies.linear_marlin_int8_w8a16 import LinearMarlinInt8W8A16Strategy # noqa: F401
from diffulex.utils.quantization.strategies.linear_int4_w4a16 import LinearInt4W4A16Strategy # noqa: F401
from diffulex.utils.quantization.strategies.linear_int8_w8a8 import LinearInt8W8A8Strategy # noqa: F401
from diffulex.utils.quantization.strategies.linear_int4_w4a8 import LinearInt4W4A8Strategy # noqa: F401
@@ -23,6 +24,7 @@
'LinearBF16Strategy',
'LinearStubStrategy',
'LinearInt8W8A16Strategy',
'LinearMarlinInt8W8A16Strategy',
'LinearInt4W4A16Strategy',
'LinearInt8W8A8Strategy',
'LinearInt4W4A8Strategy',
34 changes: 32 additions & 2 deletions diffulex/utils/quantization/strategies/linear_awq_w4a16.py
@@ -26,6 +26,15 @@
except ImportError:
awq_w4a16_gemm = None

try:
from diffulex.attention.metadata import is_warming_up
from tilelang.autotuner import set_autotune_inputs
_AUTOTUNE_AVAILABLE = True
except ImportError:
_AUTOTUNE_AVAILABLE = False
is_warming_up = lambda: False
set_autotune_inputs = lambda *args, **kwargs: lambda f: f


def _unpack_awq_int4(
packed: torch.Tensor,
@@ -184,6 +193,8 @@ class LinearAWQW4A16Strategy(LinearQuantizationStrategy):
def __init__(self):
"""Initialize strategy (no cache needed when using kernel)."""
super().__init__()
# TileLang autotune config cache: (device, M_bucket, N, K, num_groups, group_size) -> config dict
self._tl_autotune_config_cache: dict[tuple[str, int, int, int, int, int], dict] = {}

@property
def name(self) -> str:
@@ -381,8 +392,27 @@ def linear_forward(
x_pad[:M, :] = x
x_for_kernel = x_pad

# Compile kernel (cached by TileLang)
kernel = awq_w4a16_gemm(M_bucket, N, K, num_groups, group_size, block_M=64, block_N=64, block_K=128, num_stages=2, threads=128)
# TileLang autotune: use warmup + config cache pattern
cache_key = (str(x.device), M_bucket, N, K, num_groups, group_size)
config = self._tl_autotune_config_cache.get(cache_key)

if _AUTOTUNE_AVAILABLE and is_warming_up() and config is None:
# Warmup phase: run autotune with real inputs
try:
with set_autotune_inputs([x_for_kernel, qweight, qzeros, scales]):
kernel = awq_w4a16_gemm(M_bucket, N, K, num_groups, group_size)
config = kernel.config
self._tl_autotune_config_cache[cache_key] = config
except Exception:
# Fallback to default config if autotune fails
config = None

# Use cached config or default parameters
if config is not None:
kernel = awq_w4a16_gemm(M_bucket, N, K, num_groups, group_size, **config)
else:
# Default config (backward compatible)
kernel = awq_w4a16_gemm(M_bucket, N, K, num_groups, group_size, block_M=64, block_N=64, block_K=128, num_stages=2, threads=128)

# Call kernel - out_idx=[4] means output is the 5th parameter
output_full = kernel(x_for_kernel, qweight, qzeros, scales)
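Stripped to its essentials, the warmup-and-cache flow added here (and mirrored in the FP8 W8A16 strategy below) is: autotune once per problem shape during warmup with real inputs, persist kernel.config, and rebuild the kernel with either the cached config or the hard-coded defaults on later calls. A condensed sketch using the identifiers from the diff (error handling omitted):

config = self._tl_autotune_config_cache.get(cache_key)
if _AUTOTUNE_AVAILABLE and is_warming_up() and config is None:
    # Warmup pass: let TileLang autotune against the real tensors, then remember the result.
    with set_autotune_inputs([x_for_kernel, qweight, qzeros, scales]):
        kernel = awq_w4a16_gemm(M_bucket, N, K, num_groups, group_size)
    config = kernel.config
    self._tl_autotune_config_cache[cache_key] = config
# Steady state: rebuild with the tuned config, or fall back to the previous defaults.
kernel = awq_w4a16_gemm(
    M_bucket, N, K, num_groups, group_size,
    **(config or dict(block_M=64, block_N=64, block_K=128, num_stages=2, threads=128)),
)
output_full = kernel(x_for_kernel, qweight, qzeros, scales)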
38 changes: 36 additions & 2 deletions diffulex/utils/quantization/strategies/linear_fp8_w8a16.py
@@ -40,6 +40,15 @@
except ImportError:
pass

try:
from diffulex.attention.metadata import is_warming_up
from tilelang.autotuner import set_autotune_inputs
_AUTOTUNE_AVAILABLE = True
except ImportError:
_AUTOTUNE_AVAILABLE = False
is_warming_up = lambda: False
set_autotune_inputs = lambda *args, **kwargs: lambda f: f


@register_linear_strategy(weight_dtype="fp8_e4m3", act_dtype="bf16")
def _build_linear_fp8_e4m3_w8a16() -> LinearQuantizationStrategy:
@@ -80,6 +89,8 @@ def __init__(self, weight_dtype: str = "fp8_e4m3"):
self._weight_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
# Optional cache: weight_id -> bf16 dequantized weight (speed-first; uses extra memory)
self._dequant_weight_cache: dict[int, torch.Tensor] = {}
# TileLang autotune config cache: (device, M_bucket, N, K) -> config dict
self._tl_autotune_config_cache: dict[tuple[str, int, int, int], dict] = {}

@property
def name(self) -> str:
@@ -301,8 +312,31 @@ def linear_forward(
x_pad[:M, :] = x
x_for_kernel = x_pad

# Compile kernel (cached by TileLang)
kernel = fp8_w8a16_gemm(M_bucket, N, K, block_M=64, block_N=64, block_K=128, num_stages=2, threads=128)
# TileLang autotune: use warmup + config cache pattern
cache_key = (str(x.device), M_bucket, N, K)
config = self._tl_autotune_config_cache.get(cache_key)

if _AUTOTUNE_AVAILABLE and is_warming_up() and config is None:
# Warmup phase: run autotune with real inputs
try:
assert self.spec.fp8_view_dtype is not None
qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype)
with set_autotune_inputs([x_for_kernel, qweight_fp8, scales]):
kernel = fp8_w8a16_gemm(M_bucket, N, K)
config = kernel.config
self._tl_autotune_config_cache[cache_key] = config
except Exception:
# Fallback to default config if autotune fails
config = None

# Use cached config or default parameters
assert self.spec.fp8_view_dtype is not None
qweight_fp8 = quantized_weight.view(self.spec.fp8_view_dtype)
if config is not None:
kernel = fp8_w8a16_gemm(M_bucket, N, K, **config)
else:
# Default config (backward compatible)
kernel = fp8_w8a16_gemm(M_bucket, N, K, block_M=64, block_N=64, block_K=128, num_stages=2, threads=128)

# Call kernel - out_idx=[3] means output is the 4th parameter
assert self.spec.fp8_view_dtype is not None