
fix: apply per-tensor weight scales during FP8 dequantization#172

Open
Julio G (sumurtk2) wants to merge 1 commit into Lightricks:main from sumurtk2:fix/fp8-scaled-checkpoint-dequant

Conversation

@sumurtk2

Problem

Using --quantization fp8-cast with the pre-quantized FP8 checkpoint (ltx-2.3-22b-dev-fp8.safetensors) produces static noise instead of coherent video. This affects all pipelines (ti2vid_two_stages, ti2vid_one_stage, distilled, etc.) when using the FP8 checkpoint with fp8-cast mode on non-Hopper GPUs (where fp8-scaled-mm / TensorRT-LLM is not available).

Root Cause

The FP8 checkpoint stores weights in float8_e4m3fn format with per-tensor scale factors (weight_scale and input_scale tensors alongside each weight). For example:

transformer_blocks.10.attn1.to_k.weight       → float8_e4m3fn [4096, 4096]
transformer_blocks.10.attn1.to_k.weight_scale → float32 scalar (0.0013)
transformer_blocks.10.attn1.to_k.input_scale  → float32 scalar (0.0157)

The fp8-cast upcast path (_upcast_and_round) performs weight.to(bfloat16) — a raw type cast that ignores the scale factors. This produces values ~770x too large:

# Without scale (current behavior): std = 26.9  ← WRONG
w_up = weight.to(bfloat16)

# With scale (correct behavior):    std = 0.035 ← CORRECT
w_up = weight.to(bfloat16) * weight_scale
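
The ~770x figure follows directly from the example scale: skipping the multiplication leaves values inflated by a factor of 1/weight_scale. A quick sanity check in plain Python, using the numbers quoted above:

```python
weight_scale = 0.0013            # per-tensor scale from the to_k example above
inflation = 1.0 / weight_scale   # error factor when the scale is skipped
print(round(inflation))          # ~770x, matching the description above

# The measured stds agree: 26.9 / 0.035 gives the same factor.
print(round(26.9 / 0.035))
```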

Additionally, load_state_dict(strict=False) silently drops the weight_scale tensors since they don't match any model parameters, so the scale information is lost before _amend_forward_with_upcast runs.
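
A minimal sketch of the extraction step (pure Python on a mock state dict; `extract_fp8_scales` and the dict contents are illustrative, while `_fp8_weight_scales` is the attribute name the fix actually uses):

```python
# Pull *_scale entries out of the state dict BEFORE load_state_dict(strict=False)
# silently drops them, so they can be stashed on the model afterwards.
def extract_fp8_scales(state_dict):
    """Remove and return all weight_scale entries from a state dict."""
    scale_keys = [k for k in state_dict if k.endswith("weight_scale")]
    return {k: state_dict.pop(k) for k in scale_keys}

state_dict = {
    "transformer_blocks.10.attn1.to_k.weight": "fp8 tensor",
    "transformer_blocks.10.attn1.to_k.weight_scale": 0.0013,
}
fp8_weight_scales = extract_fp8_scales(state_dict)
# model.load_state_dict(state_dict, strict=False)   # scales already saved
# model._fp8_weight_scales = fp8_weight_scales      # stashed for fp8_cast.py
print(fp8_weight_scales)  # only the scale entry
print(state_dict)         # only the weight remains
```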

Fix

Two changes, both in ltx-core:

  1. single_gpu_model_builder.py: Extract weight_scale tensors from the state dict before load_state_dict discards them, and stash them on the model as _fp8_weight_scales.

  2. fp8_cast.py: _amend_forward_with_upcast now retrieves the stashed scales and passes them to _replace_fwd_with_upcast, which multiplies the dequantized weight by the scale factor during inference.
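
The second change can be sketched as follows, using plain Python floats as a stand-in for torch tensors (the real `_replace_fwd_with_upcast` operates on FP8 tensors; names other than `weight_scale` are illustrative):

```python
# Stand-in for the patched upcast path: upcast the stored FP8 values,
# then multiply by the per-tensor scale if one was stashed on the model.
def upcast_weight(fp8_values, weight_scale=None):
    """Upcast FP8-stored values, applying the per-tensor scale if present."""
    upcast = [float(v) for v in fp8_values]    # stands in for weight.to(bfloat16)
    if weight_scale is None:
        return upcast                          # no scale: behavior unchanged
    return [v * weight_scale for v in upcast]  # dequantize: w * weight_scale

raw = [104.0, -52.0, 208.0]            # FP8 code values, far too large as-is
print(upcast_weight(raw, 0.0013))      # scaled back to the true magnitudes
print(upcast_weight(raw))              # no scale stashed: raw cast, as before
```

The `weight_scale=None` default is what keeps the on-the-fly BF16 quantization path untouched, as described under Backward Compatibility below.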

Backward Compatibility

When no weight_scale tensors are present (e.g., when using fp8-cast to quantize a BF16 checkpoint on the fly), the behavior is unchanged: weight_scale defaults to None and the multiplication is skipped.

Testing

Verified on RTX 5090 (Blackwell, sm_120) with:

  • ltx-2.3-22b-dev-fp8.safetensors checkpoint
  • gemma-3-12b-it-qat-q4_0-unquantized text encoder
  • One-stage pipeline at 512×320, 30 steps → produces coherent video with audio

Before fix: static noise (every generation, every seed)
After fix: correct video output matching prompt

Related Issues

The pre-quantized FP8 checkpoint (ltx-2.3-22b-dev-fp8.safetensors) stores
weights in float8_e4m3fn format with per-tensor weight_scale and input_scale
factors. When using fp8-cast quantization mode, the upcast path performs a
naive .to(bfloat16) without applying these scale factors, producing weight
values that are ~770x too large and resulting in noise output instead of
coherent video.

This commit fixes the issue by:

1. Extracting weight_scale tensors from the state dict before
   load_state_dict(strict=False) discards them (in SingleGPUModelBuilder)
2. Passing the scales through to _replace_fwd_with_upcast via the model
3. Multiplying dequantized weights by their scale factor during inference

The fix is backward-compatible: when no weight_scale tensors are present
(e.g. when using fp8-cast to quantize a BF16 checkpoint on the fly), the
behavior is unchanged.

Fixes noise output when running:
  python -m ltx_pipelines.ti2vid_two_stages \
    --checkpoint-path ltx-2.3-22b-dev-fp8.safetensors \
    --quantization fp8-cast ...

Related: Lightricks#165
