fix: apply per-tensor weight scales during FP8 dequantization #172
Open
Julio G (sumurtk2) wants to merge 1 commit into Lightricks:main from
Conversation
The pre-quantized FP8 checkpoint (ltx-2.3-22b-dev-fp8.safetensors) stores
weights in float8_e4m3fn format with per-tensor weight_scale and input_scale
factors. When using fp8-cast quantization mode, the upcast path performs a
naive .to(bfloat16) without applying these scale factors, producing weight
values that are ~770x too large and resulting in noise output instead of
coherent video.
This commit fixes the issue by:
1. Extracting weight_scale tensors from the state dict before
load_state_dict(strict=False) discards them (in SingleGPUModelBuilder)
2. Passing the scales through to _replace_fwd_with_upcast via the model
3. Multiplying dequantized weights by their scale factor during inference
The fix is backward-compatible: when no weight_scale tensors are present
(e.g. when using fp8-cast to quantize a BF16 checkpoint on the fly), the
behavior is unchanged.
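Step 1 above can be sketched without torch. The helper below is a hypothetical illustration (the function name, the toy state dict, and the 0.0013 scale value are assumptions, not code from the PR); it shows the key manipulation that must happen before load_state_dict(strict=False) silently drops the scale entries:

```python
def extract_weight_scales(state_dict):
    """Pop per-tensor weight_scale entries out of the state dict,
    keyed by the weight they belong to, before load_state_dict
    (strict=False) silently discards them."""
    scales = {}
    for key in list(state_dict):
        if key.endswith(".weight_scale"):
            # "blocks.0.attn.to_q.weight_scale" -> "blocks.0.attn.to_q.weight"
            scales[key[: -len("_scale")]] = state_dict.pop(key)
    return scales

# Toy state dict; in the real pipeline the values are tensors.
sd = {
    "blocks.0.attn.to_q.weight": "fp8_weight_bytes",
    "blocks.0.attn.to_q.weight_scale": 0.0013,
    "blocks.0.attn.to_q.input_scale": 0.25,
}
scales = extract_weight_scales(sd)
print(scales)  # {'blocks.0.attn.to_q.weight': 0.0013}
```

The extracted dict can then be stashed on the model (the PR uses an attribute named _fp8_weight_scales) for the upcast path to consume later.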
Fixes noise output when running:
python -m ltx_pipelines.ti2vid_two_stages \
--checkpoint-path ltx-2.3-22b-dev-fp8.safetensors \
--quantization fp8-cast ...
Related: Lightricks#165
Problem

Using --quantization fp8-cast with the pre-quantized FP8 checkpoint (ltx-2.3-22b-dev-fp8.safetensors) produces static noise instead of coherent video. This affects all pipelines (ti2vid_two_stages, ti2vid_one_stage, distilled, etc.) when using the FP8 checkpoint with fp8-cast mode on non-Hopper GPUs (where fp8-scaled-mm / TensorRT-LLM is not available).

Root Cause
The FP8 checkpoint stores weights in float8_e4m3fn format with per-tensor scale factors (weight_scale and input_scale tensors alongside each weight).

The fp8-cast upcast path (_upcast_and_round) performs weight.to(bfloat16), a raw type cast that ignores the scale factors. This produces values ~770x too large.

Additionally, load_state_dict(strict=False) silently drops the weight_scale tensors since they don't match any model parameters, so the scale information is lost before _amend_forward_with_upcast runs.

Fix
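The root cause can be demonstrated with plain arithmetic (the specific scale and weight values below are illustrative assumptions, chosen so the error matches the ~770x observed in the PR):

```python
# Per-tensor FP8 quantization stores w_fp8 = round_to_fp8(w / weight_scale),
# so the raw stored values are 1/weight_scale times larger than the true
# weights. A cast that ignores the scale reproduces that inflated value.
weight_scale = 0.0013                    # plausible per-tensor scale (assumption)
true_weight = 0.02                       # a BF16 weight before quantization
stored_fp8 = true_weight / weight_scale  # what lives in the checkpoint

raw_cast = stored_fp8                    # .to(bfloat16): scale ignored
correct = stored_fp8 * weight_scale      # dequantized with the scale

print(raw_cast / correct)  # ~769: exactly 1 / weight_scale
```

With weight_scale around 0.0013, every weight comes out roughly 770x too large, which is why the denoiser produces pure noise rather than merely degraded output.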
Two changes, both in ltx-core:

1. single_gpu_model_builder.py: Extract weight_scale tensors from the state dict before load_state_dict discards them, and stash them on the model as _fp8_weight_scales.
2. fp8_cast.py: _amend_forward_with_upcast now retrieves the stashed scales and passes them to _replace_fwd_with_upcast, which multiplies the dequantized weight by the scale factor during inference.

Backward Compatibility
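The forward-patching side of the fix can be sketched without torch. Linear here is a stand-in for the real module, and the wrapper mirrors the idea of _replace_fwd_with_upcast rather than copying its code; all names and values are illustrative assumptions:

```python
class Linear:
    """Stand-in for an nn.Linear whose weight holds the raw FP8-stored value."""
    def __init__(self, stored_weight):
        self.weight = stored_weight

    def forward(self, x):
        return x * self.weight

def replace_fwd_with_upcast(module, weight_scale=None):
    """Wrap forward so the upcast weight is rescaled before use."""
    original_forward = module.forward

    def patched(x):
        if weight_scale is not None:
            # Dequantize: stored FP8 value times its per-tensor scale.
            return x * (module.weight * weight_scale)
        # No scale present (BF16 quantized on the fly): behavior unchanged.
        return original_forward(x)

    module.forward = patched

layer = Linear(stored_weight=15.38)  # roughly 0.02 / 0.0013
replace_fwd_with_upcast(layer, weight_scale=0.0013)
print(layer.forward(1.0))  # ~0.02, the true weight
```

The weight_scale=None default is what keeps the change backward-compatible, as the next section notes.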
When no weight_scale tensors are present (e.g., when using fp8-cast to quantize a BF16 checkpoint on the fly), the behavior is unchanged: weight_scale defaults to None and the multiplication is skipped.

Testing
Verified on RTX 5090 (Blackwell, sm_120) with:
- ltx-2.3-22b-dev-fp8.safetensors checkpoint
- gemma-3-12b-it-qat-q4_0-unquantized text encoder

Before fix: static noise (every generation, every seed)
After fix: correct video output matching prompt
Related Issues
Lightricks#165 (fp8-scaled-mm crashes with TypeError; the only other FP8 path)