Feat/kv cache fp8 support #23
Conversation
…s; remove unused checker.py
…rom global memory fetching into fragment fetching
…ilable, checking errors of cuda graph capturing fixed.
… and WARP_SPECIALIZATION
- Add LinearQuantizationStrategy interface supporting weight+activation quantization
- Support layer-type-specific strategies (attn/mlp/other)
- Add registry system for linear quantization strategies
- Add Config fields: linear_attn_weight_dtype, linear_mlp_weight_dtype, linear_attn_act_dtype, linear_mlp_act_dtype
- Integrate factory to inject strategies into QuantizationContext
- Add dynamic dispatch in Linear.forward() based on quant_kind
- Tag Linear layers in models (dream/llada/sdar/fast_dllm_v2) with quant_kind
- Add placeholder strategies (stub) that raise NotImplementedError for non-bf16 dtypes
- Add unit tests for registry/factory/dispatch behavior
- Default bf16 behavior unchanged (fully backward compatible)

All non-bf16 paths currently raise NotImplementedError with clear error messages, providing a stable interface for future kernel/packed weight implementations.
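For readers following along, here is a minimal sketch of the registry/dispatch pattern this commit describes. The names (LinearQuantizationStrategy, register_linear_strategy, create_linear_strategy) echo the commit message, but the bodies are illustrative assumptions, not the actual Diffulex implementation.

```python
# Illustrative sketch of a strategy registry with dynamic dispatch; not the actual Diffulex code.
from typing import Callable, Dict

import torch
import torch.nn.functional as F


class LinearQuantizationStrategy:
    """Interface: each strategy decides how to run a Linear forward for its dtype."""

    def forward(self, x: torch.Tensor, weight: torch.Tensor, bias=None) -> torch.Tensor:
        raise NotImplementedError


_LINEAR_STRATEGIES: Dict[str, Callable[[], LinearQuantizationStrategy]] = {}


def register_linear_strategy(dtype_name: str):
    """Decorator that records a strategy factory under a dtype name."""
    def decorator(factory):
        _LINEAR_STRATEGIES[dtype_name] = factory
        return factory
    return decorator


def create_linear_strategy(dtype_name: str) -> LinearQuantizationStrategy:
    if dtype_name not in _LINEAR_STRATEGIES:
        # Mirrors the "placeholder strategies raise NotImplementedError" behaviour above.
        raise NotImplementedError(f"no linear quantization strategy registered for {dtype_name!r}")
    return _LINEAR_STRATEGIES[dtype_name]()


@register_linear_strategy("bf16")
class Bf16PassthroughStrategy(LinearQuantizationStrategy):
    """Default path: plain F.linear, behaviour unchanged."""

    def forward(self, x, weight, bias=None):
        return F.linear(x, weight, bias)
```

A Linear.forward() could then look up the strategy for its quant_kind (attn/mlp/other) from the QuantizationContext and delegate to it, which is what "dynamic dispatch based on quant_kind" suggests.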
…6 activation)
- Add LinearInt8W8A16Strategy with per-channel symmetric quantization
- Reference implementation using Python dequantization + F.linear
- Quantization: per-output-channel scales, int8 weight storage
- Activation: remains bf16 (no activation quantization)
- Update tests to verify W8A16 strategy (quantization/forward correctness)
- Update placeholder documentation with implementation status

Performance notes:
- Current implementation quantizes weights on every forward (no caching)
- Future optimization: lazy cache quantized weights per module instance
- Future optimization: replace F.linear with custom int8 GEMM kernel

This provides a working reference implementation for W8A16 quantization, enabling correctness validation before moving to optimized kernels.
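A minimal sketch of what per-channel symmetric W8A16 quantization with a Python dequant + F.linear reference path looks like; the helper names are hypothetical and the real strategy code may differ.

```python
# Illustrative sketch of per-channel symmetric int8 weight quantization and the
# bf16 reference forward; helper names are hypothetical.
import torch
import torch.nn.functional as F


def quantize_per_channel_int8(weight: torch.Tensor):
    """weight: [out_features, in_features]; one symmetric scale per output channel."""
    absmax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scales = absmax / 127.0
    q = torch.round(weight / scales).clamp_(-127, 127).to(torch.int8)
    return q, scales.squeeze(1)  # int8 weights, per-channel floating-point scales


def w8a16_forward_reference(x: torch.Tensor, q: torch.Tensor, scales: torch.Tensor, bias=None):
    # Reference path: dequantize back to the activation dtype, then plain F.linear.
    w = q.to(x.dtype) * scales.to(x.dtype).unsqueeze(1)
    return F.linear(x, w, bias)
```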
- Add weight quantization cache keyed by weight tensor id()
- Cache stores (quantized_weight, scales) tuple per weight
- First forward quantizes and caches, subsequent forwards reuse cache
- Add clear_cache() method for memory management
- Add unit test to verify cache behavior

Performance improvement:
- Eliminates redundant quantization on every forward pass
- Significant speedup for decode phase (where the same weights are reused)
- Cache automatically handles device placement

This addresses the performance concern mentioned in the placeholder documentation, where every forward was re-quantizing weights.
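A sketch of the id()-keyed lazy cache, reusing the quantize_per_channel_int8 helper from the previous sketch; note id() is only stable for the lifetime of the tensor, which is fine for long-lived module weights.

```python
# Illustrative sketch of the id()-keyed lazy cache; relies on the
# quantize_per_channel_int8 helper from the previous sketch.
class WeightQuantCache:
    def __init__(self):
        self._cache = {}  # id(weight) -> (int8_weight, scales)

    def get_or_quantize(self, weight):
        key = id(weight)
        if key not in self._cache:
            # First forward pays the quantization cost; later forwards reuse the entry.
            self._cache[key] = quantize_per_channel_int8(weight)
        return self._cache[key]

    def clear_cache(self):
        self._cache.clear()
```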
- Implement W8A16 GEMM kernel using TileLang with per-channel dequantization
- Integrate kernel into LinearInt8W8A16Strategy with robust error handling
- Add comprehensive error handling: CUDA device checks, compute capability detection, shape constraints
- Automatic fallback to Python reference implementation when kernel unavailable
- Add unit tests for kernel correctness and lazy cache functionality
- Update documentation to reflect implementation status

Performance: Prefill ~110 tok/s, Decode ~43 tok/s (with cached kernels)
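The fallback logic could look roughly like the following; w8a16_gemm_kernel stands in for the compiled TileLang kernel, and the concrete checks (compute capability >= 8, shape sanity) are illustrative assumptions rather than the exact constraints in this commit.

```python
# Illustrative sketch of the "compiled kernel if usable, else Python reference" dispatch.
# w8a16_gemm_kernel is a placeholder for the TileLang kernel; the concrete checks are assumptions.
import torch


def w8a16_linear(x, q_weight, scales, bias=None):
    def kernel_usable() -> bool:
        if not x.is_cuda:
            return False
        major, _ = torch.cuda.get_device_capability(x.device)
        if major < 8:                                # assumed minimum compute capability
            return False
        return x.shape[-1] == q_weight.shape[1]      # basic shape sanity check

    if kernel_usable():
        try:
            return w8a16_gemm_kernel(x, q_weight, scales, bias)  # hypothetical compiled kernel
        except Exception:
            pass  # any kernel-side failure falls through to the reference path
    return w8a16_forward_reference(x, q_weight, scales, bias)
```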
…ribution guidelines
- Add a warmup parameter to the test_generation function to exclude kernel compilation overhead
- Run a warmup pass for each path (BF16+BF16 KV and BF16+FP8 KV) before the actual test
- Add a performance comparison output contrasting TPS and elapsed time between the two paths
- Improve the output format to show detailed performance metrics and comparison results
…near layers
- Add load-time quantization in LinearBase._maybe_quantize_loaded_weight_param()
- Quantize weights during weight_loader and store as int8 buffers
- Remove original bf16 weight Parameter to save GPU memory (~2x reduction)
- Handle multi-shard weights (QKV/Merged) by waiting for all shards before replacement
- Update LinearInt8W8A16Strategy to consume quantized buffers directly
- Skip lazy cache when load-time quantized buffers are present
- Add M-bucketing for prefill to reduce kernel compilation overhead
- Optimize TileLang W8A16 kernel to handle tail dimensions
- Implement dual-path kernel (aligned vs tail-safe) using masking
- Remove K dimension alignment requirement, preventing fallbacks
- Add comprehensive tests for load-time quantization
- Verify weight Parameter removal and buffer usage
- Test memory savings and numerical correctness
- Update test_w8a16_generation.py with W8A16+FP8 KV mixed path performance comparison
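A rough sketch of the load-time replacement of the bf16 weight Parameter with int8 buffers; the shard-counting arguments and buffer names are assumptions for illustration, not the actual LinearBase code.

```python
# Illustrative sketch of load-time quantization: once all shards of a weight have been
# loaded, swap the bf16 Parameter for int8 buffers. Argument and buffer names are assumptions.
import torch


def maybe_quantize_loaded_weight(module: torch.nn.Module, loaded_shards: int, total_shards: int):
    if loaded_shards < total_shards:
        return  # QKV/Merged layers: wait until every shard has been copied in
    weight = module.weight.data
    q, scales = quantize_per_channel_int8(weight)  # from the earlier sketch
    del module.weight                              # drop the bf16 Parameter (~2x memory saving)
    module.register_buffer("weight_int8", q)
    module.register_buffer("weight_scales", scales)
```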
…dequant to output scaling
- Move per-channel scale multiplication from K-loop weight dequant to output column scaling
- Mathematical equivalence: (A @ (q*s)^T) = (A @ q^T) * s for per-channel scales
- Reduces register pressure, type conversions, and intermediate buffers in the hot path
- Applied to both w8a16_gemm and w4a16_gemm kernels
- Fix test_w8a16_tilelang_kernel_correctness: use masked relative error check
  - Avoids false failures when ref_output is near zero
  - Only checks relative error where ref_output.abs() > 1.0
- Improve test_w8a16_generation.py cleanup logic
  - Ensure proper cleanup (destroy_process_group, empty_cache, gc.collect) even on exceptions
- Add W4A16 strategy implementation and test script
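The per-channel scaling identity this relies on can be checked directly in a few lines of PyTorch (illustrative only):

```python
# Quick numerical check of the identity: A @ (q * s)^T == (A @ q^T) * s
# when s holds one scale per output channel (i.e. per column of the result).
import torch

A = torch.randn(4, 64)                           # activations [M, K]
q = torch.randint(-127, 128, (8, 64)).float()    # int8-valued weights [N, K]
s = torch.rand(8) + 0.1                          # per-output-channel scales [N]

dequant_in_k_loop = A @ (q * s.unsqueeze(1)).t() # scale folded into the weights
scale_on_output = (A @ q.t()) * s                # scale applied to the output columns
assert torch.allclose(dequant_in_k_loop, scale_on_output, rtol=1e-4, atol=1e-3)
```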
[CI] Add Dependabot.yml into GitHub Workflow
- Change weight scales dtype from BF16 to FP16 for W8A8/W4A8 strategies to reduce quantization errors
- Update w8a8_scaled_gemm and w4a8_scaled_gemm kernels to accept FP16 scales instead of BF16
- Add W8A8 and W4A8 quantization strategies (linear_int8_w8a8.py, linear_int4_w4a8.py)
- Merge test scripts into unified test_quantization_generation.py
- Add mixed precision option for W4A8 (MLP A8 + Attn A16) to improve quality
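For context, a sketch of dynamic per-token int8 activation quantization with FP16 scales, the kind of inputs a scaled int8 GEMM such as w8a8_scaled_gemm would take; the helper name and the per-token granularity are assumptions.

```python
# Illustrative sketch of dynamic per-token int8 activation quantization with FP16 scales;
# granularity and names are assumptions, not the project's actual code.
import torch


def quantize_activations_per_token_int8(x: torch.Tensor):
    """x: [tokens, in_features] -> (int8 activations, one FP16 scale per token)."""
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scales = (absmax / 127.0).to(torch.float16)      # FP16 scales, as in the commit above
    q = torch.round(x / scales.to(x.dtype)).clamp_(-127, 127).to(torch.int8)
    return q, scales
```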
… modify uvicorn index URL, and improve error handling in attention module; remove unused profiling function from example scripts
…e models, datasets, and logging; add configuration management and command-line interface
…ex; introduce logger module and refactor existing code to utilize logging instead of print statements
… DP worker and evaluation collapse when DP enabled
…ar backends and comprehensive metrics collection
[Feat] Enhance Decoding Strategies for Easier Development and More Efficient Inference
- Delete AttnQ strategy implementations (attn_q_bf16.py, attn_q_fp8_stub.py)
- Remove AttnQQuantizationStrategy base class from strategy.py
- Remove attn_q related methods from context.py (get_attn_q_strategy, set_attn_q_strategy)
- Remove attn_q registry functions from registry.py (register_attn_q_strategy, create_attn_q_strategy, registered_attn_q_dtypes)
- Remove attn_q exports from __init__.py
- Remove attn_q_dtype from config.py (ActivationQuantConfig)
- Remove attn_q strategy creation from factory.py
- Update kernel code (dllm_flash_attn.py) to use fixed BF16 for Q (removed get_attn_q_strategy calls)
- Remove q_scale field from _AttnMetaDataLike protocol
…port

# Conflicts:
#	diffulex/__init__.py
#	diffulex/engine/model_runner.py
#	diffulex_kernel/__init__.py
#	diffulex_kernel/python/dllm_flash_attn_kernels.py
#	test/python/test_linear_fp8.py
#	test/python/test_linear_quantization_module.py
#	test/python/test_quantization_e2e.py
#	test/python/test_quantization_module.py
#	test/python/test_quantization_paths.py
#	test/test_gptq_awq_strategies.py
- Fix the logic error in the scale/absmax conversion in the update_scales method
- Now correctly convert the scale to an absmax before comparing and updating
- Matches the RunningMax approach used in vLLM
- Add detailed comments describing the update flow
- Update the quantization test script and configuration files
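A minimal sketch of the scale-to-absmax conversion in a running-max update, assuming FP8 E4M3 with a representable maximum of 448.0; the class and method names are illustrative, not the project's actual code.

```python
# Illustrative sketch of a running-max scale update, assuming FP8 E4M3 with a
# representable maximum of 448.0.
import torch

FP8_E4M3_MAX = 448.0


class RunningMaxScale:
    def __init__(self):
        self.absmax = torch.zeros(())  # running absolute maximum observed so far

    def update_scales(self, observed_scale: torch.Tensor) -> torch.Tensor:
        # Convert the incoming scale back to an absmax before comparing,
        # then re-derive the scale from the updated running absmax.
        observed_absmax = observed_scale * FP8_E4M3_MAX
        self.absmax = torch.maximum(self.absmax, observed_absmax)
        return self.absmax / FP8_E4M3_MAX
```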
- Remove the .cursor directory from git tracking
- Add .cursor/ to .gitignore so it is not committed again by mistake
Force-pushed from 67686e0 to 7b15d65
- Optimize W8A16 small-M decode: pad M<16 to 16 (instead of 64) and use block_M=16/32/64
- Add w8a16_gemm_bias kernel with fused bias epilogue (opt-in via DIFFULEX_W8A16_FUSE_BIAS)
- Add runtime profiling hooks for W8A16 (DIFFULEX_LINEAR_PROFILE) to track M distribution and fallbacks
- Implement FP8 KV varlen fused dequantization kernel (Triton) for unified layout
- Add benchmark configs for W4A8 and W8A8 quantization strategies
- Add profiling hooks for KV cache load timing (DIFFULEX_PROFILE_KVCACHE)
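A sketch of the small-M padding and block_M selection described in the first bullet; the helper is hypothetical and only illustrates the rounding logic.

```python
# Illustrative sketch of small-M padding (M<16 padded to 16) and block_M selection from {16, 32, 64}.
import torch
import torch.nn.functional as F


def pad_m_and_pick_block(x: torch.Tensor):
    m = x.shape[0]
    padded_m = max(16, ((m + 15) // 16) * 16)                 # round up to a multiple of 16, minimum 16
    block_m = next(b for b in (16, 32, 64) if b >= min(padded_m, 64))
    if padded_m != m:
        x = F.pad(x, (0, 0, 0, padded_m - m))                 # zero-pad the extra rows
    return x, m, block_m                                      # keep original m to slice the output back
```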
Main additions:
1. **Marlin/AllSpark INT8 W8A16 quantization strategy integration**:
   - Add linear_marlin_int8_w8a16.py: a W8A16 quantization strategy built on the vLLM AllSpark kernel
   - Add diffulex_kernel/csrc/marlin/: vendored AllSpark CUDA kernels from vLLM
     * allspark_qgemm_w8a16.cu: W8A16 fused GEMM kernel
     * allspark_repack.cu: N32K16 weight repacking kernel
     * allspark_utils.cuh: utility functions and data structures
     * torch_bindings_marlin.cpp: PyTorch C++ bindings
   - Add diffulex_kernel/python/marlin_ops.py: Python interface for JIT-compiling and loading the Marlin/AllSpark kernels
2. **Quantization strategy registry updates**:
   - Add a 'marlin' alias in registry.py (mapped to marlin_int8); see the alias sketch after this list
   - Import the new strategy in strategies/__init__.py
3. **Performance improvements**:
   - The Marlin W8A16 strategy significantly improves prefill throughput (from 4518.92 tok/s to 9520.91 tok/s, roughly 2.1x)
   - Decode throughput is close to the BF16 baseline (23.16 tok/s vs 23.36 tok/s)
   - Can be combined with the FP8 KV cache
4. **Other improvements**:
   - Optimized several quantization strategy implementations
   - Improved KV cache management
   - Enhanced profiler functionality
   - Added several new benchmark configuration files
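As referenced in item 2, a tiny sketch of how an alias such as 'marlin' might resolve to the registered marlin_int8 strategy (illustrative; builds on the registry sketch earlier in this thread):

```python
# Illustrative sketch of alias resolution in the strategy registry.
_LINEAR_STRATEGY_ALIASES = {"marlin": "marlin_int8"}


def resolve_linear_strategy_name(name: str) -> str:
    return _LINEAR_STRATEGY_ALIASES.get(name, name)


# e.g. create_linear_strategy(resolve_linear_strategy_name("marlin"))
# would instantiate whatever strategy is registered under "marlin_int8".
```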
Merged commit 4a6e365 into SJTU-DENG-Lab:feat/kv-cache-fp8-support
Add Linear Quantization Support