Skip to content

Add after-load fusion for static quantized MLPs#46997

Open
LiangSu8899 wants to merge 1 commit into
huggingface:mainfrom
LiangSu8899:static-quantized-mlp-fusion
Open

Add after-load fusion for static quantized MLPs#46997
LiangSu8899 wants to merge 1 commit into
huggingface:mainfrom
LiangSu8899:static-quantized-mlp-fusion

Conversation

@LiangSu8899

@LiangSu8899 LiangSu8899 commented Jul 1, 2026

Copy link
Copy Markdown

What does this PR do?

This PR adds an opt-in after-load fusion path for static quantized MLP modules.

Today, fusion_config supports pre-initialization module fusion such as patch embeddings. Some quantized fusions, however, need the quantized weights and frozen scales to be available first. Static FP8 MLP fusion is one of those cases: the replacement module needs the loaded FP8 weights, activation scales, and weight scales before it can safely dispatch to a Hub kernel.

This PR extends the fusion mapping machinery with after-load fusion specs and wires that path into the FineGrainedFP8 quantizer after weight loading. The first supported after-load fusion family is static_quantized_mlp.

Why this matters:

  • static quantized kernels can reuse frozen activation and weight scales instead of recomputing scales on every forward;
  • this is more CUDA-graph friendly for low-latency decode;
  • fused MLP kernels avoid splitting the MLP into multiple smaller ops and materializing the intermediate hidden state;
  • model integrations can move from per-model replacement/calibration/fallback glue to a config-driven path that is reusable across architectures.

External validation from a prototype HF kernels integration:

  • Standalone GeGLU MLP on RTX 5090, K=2048, H=16384:
    • decode-like M=50: BF16 0.151 ms -> static FP8 fused 0.070 ms, about 2.1x faster;
    • prefill-like M=512: BF16 0.581 ms -> static FP8 fused 0.275 ms, about 2.1x faster;
    • compared with a dynamic FP8 path, the static fused path was still about 1.2x-1.6x faster.
  • Full Pi0.5 / LeRobot VLA all-FP8 path on RTX 5090:
    • CUDA graph: static 21.56 ms vs dynamic 36.23 ms, about 1.68x faster;
    • eager: static 24.85 ms vs dynamic 47.97 ms, about 1.93x faster;
    • output quality against a BF16 reference:
      • static FP8: cosine 0.9999275, max abs 0.0195, p99 abs 0.0156, MSE 1.675e-05;
      • dynamic FP8: cosine 0.9998960, max abs 0.0215, p99 abs 0.0195, MSE 2.008e-05.
  • Task-level quality: Pi0.5 LIBERO Spatial FP8 with real-data recalibration keeps the original model success level at 98.2%.

The detailed reproduction notes are available at https://github.com/flashrt-project/FlashRT-HF-kernels/blob/main/docs/static-vs-dynamic-fp8.md, and calibration robustness notes are available at https://github.com/flashrt-project/FlashRT/blob/main/docs/calibration.md. These are external validation results and are not required by the Transformers test suite.

The implementation is intentionally opt-in and backend-neutral:

  • no default kernel repository is selected;
  • users must provide explicit Hub kernel repositories through model.config.fusion_config;
  • unsupported modules, activations, or quantization layouts fall back to the existing module path;
  • dynamic FP8, blockwise FP8, calibration, and training are out of scope for this PR.

Example config shape:

fusion_config = {
    "static_quantized_mlp": {
        "input_quant": {
            "repo_id": "owner/static-fp8-input-quant",
            "version": 1,
        },
        "gated_mlp": {
            "repo_id": "owner/static-fp8-gated-mlp",
            "version": 1,
        },
        "dense_gelu_mlp": {
            "repo_id": "owner/static-fp8-dense-gelu-mlp",
            "version": 1,
        },
    }
}

Supported in this PR:

  • static per-tensor FineGrainedFP8 FP8Linear modules;
  • gated MLPs with gate_proj, up_proj, down_proj and SiLU or tanh-approx GELU activations;
  • dense GELU MLPs with fc1, fc2, and tanh-approx GELU activation;
  • explicit Hub kernel endpoint validation and function-name overrides.

This is meant to be a narrow mechanism extension that can later be reused for related static quantized fusions, for example NVFP4/MXFP4 MLP kernels or additional MLP layouts.

Fixes # (issue)

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

This PR is opened as a draft so the author can review all changed lines before marking the Code Agent Policy checkbox and moving it to ready for review.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline and the
    Pull Request checks?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes according to the guidelines?
  • Did you write any new necessary tests?

Tests

PYTHONPATH=src PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest -p xdist.plugin \
  tests/quantization/finegrained_fp8/test_fp8.py::StaticQuantizedMLPFusionTest \
  tests/quantization/finegrained_fp8/test_fp8.py::FineGrainedFP8ConfigTest -q

Result: 8 passed, 1 warning

PYTHONPATH=src python -m pytest tests/utils/test_fusion_mapping.py -q

Result: 4 passed, 1 warning

PYTHONPATH=src python -m ruff check \
  src/transformers/integrations/static_quantized_mlp.py \
  src/transformers/fusion_mapping.py \
  src/transformers/quantizers/quantizer_finegrained_fp8.py \
  src/transformers/integrations/__init__.py \
  tests/quantization/finegrained_fp8/test_fp8.py

Result: All checks passed!

python -m ruff format \
  src/transformers/integrations/static_quantized_mlp.py \
  tests/quantization/finegrained_fp8/test_fp8.py \
  --check

Result: 2 files already formatted

PATH=/path/to/venv/bin:$PATH python utils/check_copies.py
git diff --check

Result: passed

Reviewer notes

The main API question is whether fusion_config["static_quantized_mlp"] is the right public surface for after-load quantized fusion specs, or whether maintainers would prefer a quantizer-owned registration mechanism.

Potential reviewers: quantization, kernels, and model loading maintainers.

@github-actions

github-actions Bot commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: finegrained_fp8

@LiangSu8899 LiangSu8899 force-pushed the static-quantized-mlp-fusion branch from 7cd4640 to 3d7f4ad Compare July 1, 2026 10:49
@LiangSu8899 LiangSu8899 marked this pull request as ready for review July 1, 2026 10:54
@Rocketknight1

Copy link
Copy Markdown
Member

cc @SunMarc maybe?

@github-actions

github-actions Bot commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

CI recap

Dashboard: View test results in Grafana
Latest run: 28512120492:2
Result: failure | Jobs: 14 | Tests: 71,321 | Failures: 0 | Duration: 17h 21m

@sayakpaul sayakpaul requested a review from vasqu July 1, 2026 11:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants