update get_error_factor to catch up with the latest transformers change #3021

Closed

jiqing-feng wants to merge 5 commits into huggingface:main from jiqing-feng:8bit

Conversation

@jiqing-feng
Contributor

@jiqing-feng jiqing-feng commented Feb 4, 2026

Fix failed tests: RUN_SLOW=1 pytest tests/test_gpu_examples.py::TestLoftQ

FAILED tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_iter_5[cpu] - assert tensor(2.9929e-09, grad_fn=<MeanBackward0>) < (tensor(2.0073e-09, grad_fn=<MeanBackward0>) / 1.005)
FAILED tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_8bit[cuda] - AssertionError: assert tensor(0.0868, device='cuda:0', grad_fn=<MeanBackward0>) < (tensor(0.0356, device='cuda:0', gr...

Root Cause Analysis:
The tests are failing due to changes in default dtype handling in transformers (ref: PR #42805).

| Feature | Old Behavior | New Behavior |
| --- | --- | --- |
| Standard model | Loaded as float32 | Loaded as float32 |
| Quantized model | Non-quantized parts (e.g., embeddings) loaded as fp16 | Non-quantized parts loaded as float32 |
| Test impact | Tests compared float32 vs fp16 (inconsistent) | Tests now compare float32 vs float32 (consistent) |

Conclusion: The new behavior is more reasonable. We are updating our test suite to accommodate this change.

With this change, all TestLoftQ tests can pass on CUDA-A100/XPU/CPU.

For more details, see huggingface/transformers#43725.
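
To see the new behavior directly, a check like the one below is enough (a minimal sketch; the tiny test model name and a CUDA device with bitsandbytes installed are assumptions, not part of this PR):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "peft-internal-testing/tiny-random-BloomForCausalLM"  # tiny model, just for illustration

# Quantized load: with the new transformers default, non-quantized parts stay float32
quantized = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="cuda"
)
print(quantized.get_input_embeddings().weight.dtype)  # torch.float32 (previously float16)

# Standard (non-quantized) load: float32, as before
standard = AutoModelForCausalLM.from_pretrained(model_id)
print(standard.get_input_embeddings().weight.dtype)  # torch.float32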

@jiqing-feng
Contributor Author

Hi @BenjaminBossan. Would you please review this PR? Thanks!

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng
Contributor Author

Hi @SunMarc. Would you please review this PR? Thanks!

Member

@SunMarc SunMarc left a comment

Thanks for thinking this through, makes sense

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@jiqing-feng
Contributor Author

Hi @githubnemo. Would you please review and merge this PR? Thanks!

@githubnemo
Collaborator

Thanks for researching this issue and the PR!

I'm not sure if I agree with the change.

It is true that before it was easier for LoftQ to be better since the quantized_model in the test quantized more parameters (layer norm, embedding and lm head were in fp16 in addition to the linear layers quantized in 8bit). But that doesn't change the fact that LoftQ is supposed to be better, it just removes the unfair advantage. Discounting the quantized model's loss by ~2x is just restoring the unfair advantage and possibly hiding bugs.
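
To make that concrete, the comparison in these tests boils down to a check of roughly this shape (a simplified sketch with illustrative names, not the literal test code):

def loftq_is_better(mse_loftq: float, mse_quantized: float, error_factor: float) -> bool:
    # The tests assert this condition. With error_factor > 1, LoftQ must be strictly
    # better than the plain quantized model; an error_factor < 1 would tolerate
    # LoftQ being worse, which is the concern above.
    return mse_loftq < mse_quantized / error_factor

# Numbers from the failing bloomz 8-bit run reported at the top of this PR:
print(loftq_is_better(2.9929e-09, 2.0073e-09, 1.005))  # False -> the assertion fails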

When running the tests on transformers main, only bloomz fails; all other models are fine. Can you confirm this to narrow it down a bit?

@jiqing-feng
Contributor Author

Not only bloomz; t5 fails as well.

tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_4bit[cuda] PASSED                                                                              [  5%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_4bit[cpu] PASSED                                                                               [ 10%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_4bit_iter_5[cuda] PASSED                                                                       [ 15%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_4bit_iter_5[cpu] PASSED                                                                        [ 20%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit[cuda] FAILED                                                                              [ 25%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit[cpu] FAILED                                                                               [ 30%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_iter_5[cuda] FAILED                                                                       [ 35%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_iter_5[cpu] FAILED                                                                        [ 40%]
tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_4bit[cuda] PASSED                                                                                  [ 45%]
tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_4bit[cpu] PASSED                                                                                   [ 50%]
tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_8bit[cuda] FAILED                                                                                  [ 55%]
tests/test_gpu_examples.py::TestLoftQ::test_t5_loftq_8bit[cpu] PASSED                                                                                   [ 60%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_4bit_dora[cuda] XFAIL                                                                          [ 65%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_4bit_dora[cpu] XFAIL                                                                           [ 70%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_dora[cuda] FAILED                                                                         [ 75%]
tests/test_gpu_examples.py::TestLoftQ::test_bloomz_loftq_8bit_dora[cpu] FAILED                                                                          [ 80%]
tests/test_gpu_examples.py::TestLoftQ::test_replace_lora_weights_with_loftq_using_callable PASSED                                                       [ 85%]
tests/test_gpu_examples.py::TestLoftQ::test_replace_lora_weights_with_local_model PASSED                                                                [ 90%]
tests/test_gpu_examples.py::TestLoftQ::test_config_no_loftq_init PASSED                                                                                 [ 95%]
tests/test_gpu_examples.py::TestLoftQ::test_config_no_loftq_config PASSED                                                                               [100%]

All failed tests are 8-bit LoftQ; all 4-bit tests pass.

Previously (before the transformers default dtype change in PR #42805), the non-quantized parameters (embeddings, layer norms, etc.) of the plain quantized model were loaded in fp16, while the LoftQ model used fp32. This dtype mismatch inflated the plain quantized model's error, making LoftQ appear better by comparison.

After the fix (explicitly passing dtype=torch.float32), both models use fp32, making the comparison fair. Under this fair comparison:

4-bit: Quantization error is large enough that LoftQ's SVD-based residual initialization still meaningfully reduces it → tests pass.
8-bit: Quantization already preserves high precision, so LoftQ cannot further reduce the error → tests fail.
The 8-bit LoftQ tests need their assertions adjusted accordingly.

@githubnemo

@jiqing-feng
Contributor Author

To prove it, please run the following script:

"""
Diagnose LoftQ 8-bit test failure: compare fp16 8-bit vs fp32 8-bit vs LoftQ (fp32).

Usage: python scripts/diagnose_loftq_8bit.py
"""

import tempfile
from pathlib import Path

import torch
from accelerate.utils.memory import clear_device_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from peft import LoftQConfig, LoraConfig, PeftModel, TaskType, get_peft_model

model_id = "peft-internal-testing/tiny-random-BloomForCausalLM"
device = "cuda" if torch.cuda.is_available() else "cpu"
bnb_8bit = BitsAndBytesConfig(load_in_8bit=True)


def load_model(device, **kwargs):
    return AutoModelForCausalLM.from_pretrained(model_id, device_map=device, **kwargs).eval()


with tempfile.TemporaryDirectory() as tmp_dir:
    tmp_path = Path(tmp_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("All I want is", padding=True, return_tensors="pt").to(device)

    # Base model (fp32)
    torch.manual_seed(0)
    logits_base = load_model(device)(**inputs).logits

    # 8-bit with fp16 non-quantized params (OLD behavior)
    torch.manual_seed(0)
    m = get_peft_model(load_model(device, quantization_config=bnb_8bit, dtype=torch.float16),
                       LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear"))
    logits_8bit_fp16 = m(**inputs).logits
    del m; clear_device_cache(garbage_collection=True)

    # 8-bit with fp32 non-quantized params (NEW behavior)
    torch.manual_seed(0)
    m = get_peft_model(load_model(device, quantization_config=bnb_8bit, dtype=torch.float32),
                       LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules="all-linear"))
    logits_8bit_fp32 = m(**inputs).logits
    del m; clear_device_cache(garbage_collection=True)

    # LoftQ (fp32)
    lora_cfg = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq",
                          loftq_config=LoftQConfig(loftq_bits=8), target_modules="all-linear")
    m = get_peft_model(load_model(device).to(device), lora_cfg).to(device)
    m.base_model.peft_config["default"].init_lora_weights = True
    m.save_pretrained(tmp_path / "loftq")
    m.unload().save_pretrained(tmp_path / "base")
    del m; clear_device_cache(garbage_collection=True)

    base = load_model(device, quantization_config=bnb_8bit, dtype=torch.float32)
    m = PeftModel.from_pretrained(base, tmp_path / "loftq", is_trainable=True)
    torch.manual_seed(0)
    logits_loftq = m(**inputs).logits
    del m; clear_device_cache(garbage_collection=True)

    # Compute MAE
    mae = lambda l: torch.abs(logits_base - l).mean().item()
    mae_fp16, mae_fp32, mae_lq = mae(logits_8bit_fp16), mae(logits_8bit_fp32), mae(logits_loftq)

print(f"\n{'='*70}")
print(f"  8-bit quantized (fp16 non-quant params):  MAE = {mae_fp16:.6e}")
print(f"  8-bit quantized (fp32 non-quant params):  MAE = {mae_fp32:.6e}")
print(f"  LoftQ 8-bit    (fp32 non-quant params):   MAE = {mae_lq:.6e}")
print(f"{'='*70}")
print(f"  OLD test (fp16 8bit vs LoftQ): LoftQ better? {mae_lq < mae_fp16}")
print(f"  NEW test (fp32 8bit vs LoftQ): LoftQ better? {mae_lq < mae_fp32}")
print(f"{'='*70}")

The output on an A100 is:

======================================================================
  8-bit quantized (fp16 non-quant params):  MAE = 8.185369e-05
  8-bit quantized (fp32 non-quant params):  MAE = 4.464298e-05
  LoftQ 8-bit    (fp32 non-quant params):   MAE = 4.547583e-05
======================================================================
  OLD test (fp16 8bit vs LoftQ): LoftQ better? True
  NEW test (fp32 8bit vs LoftQ): LoftQ better? False
======================================================================

You can see that the previous 8-bit model's (fp16) MAE is much larger than that of the new 8-bit model (fp32); that is why the tests now fail.

@jiqing-feng
Contributor Author

Hi @BenjaminBossan. Would you please review this PR? Thanks!

@BenjaminBossan
Member

We're still trying to get to the bottom of this; it's not quite trivial to understand what's going on.

@githubnemo
Collaborator

@jiqing-feng thank you for your patience.

I have investigated this closer and confirmed my initial suspicion that there is a more fundamental issue that is only highlighted by the recent changes in dtype conversion. You are correct insofar as the comparison is now fairer since all layers are float32. But adjusting the error factors to allow LoftQ to be worse than the unmodified quantized model is wrong: LoftQ is supposed to be better at recovering the quantization error than a quantized base model that doesn't have a mechanism to recover the quantization error at all. This assumption holds for nf4 (i.e. "4bit quantization") but breaks for 8bit, as you reported.

So, what makes 8 bit special? The actual issue here is that LoftQ never supported int8 quantization, but this fact was masked by the inequality you reported. The 8-bit mode of LoftQ assumes 8-bit normal floats, but quantizing with bitsandbytes in 8bit uses int8. This is inherently incompatible and never worked.

I am working on a fix, but most likely we will just drop support for LoftQ 8bit since it is misleading and doesn't work. It is also not as simple as mimicking the 4-bit behavior since in int8 mode, bnb also quantizes the input, which is not easily covered by a static LoRA initialization like LoftQ.
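
For intuition, the residual-recovery idea behind LoftQ can be emulated in plain torch, with int8 absmax quantization standing in for the quantizer (a rough, self-contained sketch, not PEFT's actual LoftQ code; per-tensor scaling is a simplification of what bitsandbytes does):

import torch

torch.manual_seed(0)
W = torch.randn(64, 64)

# Per-tensor int8 absmax quantize/dequantize (simplified stand-in for bnb's int8 scheme)
scale = W.abs().max() / 127
W_deq = torch.clamp((W / scale).round(), -127, 127) * scale

# LoftQ-style correction: approximate the quantization residual with a low-rank update
residual = W - W_deq
U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
r = 8
A = U[:, :r] * S[:r]   # (64, r)
B = Vh[:r, :]          # (r, 64)

print(f"MSE of plain int8 weights:             {(W - W_deq).pow(2).mean():.3e}")
print(f"MSE with low-rank residual (rank {r}):  {(W - (W_deq + A @ B)).pow(2).mean():.3e}")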

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng
Contributor Author

Hi @githubnemo. Thanks! It's very clear to me now! In that case, it would be better to skip the 8-bit LoftQ tests. Please review my new changes. Thanks!

githubnemo pushed a commit to githubnemo/peft that referenced this pull request Mar 10, 2026
With transformers v5 the GPU examples started failing for LoftQ,
mainly for 8bit quantization. This is also reported in PR huggingface#3021.
The main reason for this is that with transformers v5 the dtype
of the layers is consistent across the model whereas before
some layers were in float16 and some in float32. This masked
wrong behavior.

It turns out that LoftQ always assumed that if `n_bits==8` the
weights are in fp8. In reality, at least for bitsandbytes, this
is not true. bnb's `load_in_8bit` uses int8. This PR changes
the quantization residual computation to be compatible with int8
if `n_bits==8`, so that at least the computed residual for
the loftq initialization are correct.

This alone does not mean that the LoftQ tests for 8bit are
passing. They assume that the MSE of LoftQ 'corrected' models,
compared to an unquantized baseline model, is at least 3% lower
than that of the quantized reference model. For int8 this is not
guaranteed since `bnb.nn.Linear8bitLt` not only quantizes the
weights but also quantizes the inputs in its `.forward()`.
The resulting quantization error cannot be compensated by the
static LoftQ initialization. Since there is technically a way
to compensate for this error, I've kept the LoftQ initialization
for int8, added an example and changed the 8 bit tests to only
assume rough parity with the quantized reference model.

The example that showcases how to recover the activation
quantization error can also be run with `--no-loftq`, which skips
the LoftQ initialization and uses the LoRA no-op initialization
instead, and the example still recovers the majority of the error
(e.g. 90% improvement instead of 91% over quantized reference).

Therefore there are the following open tasks:

- discuss removal of int8 from LoftQ completely (since it can
  be argued that it doesn't seem to have such a big impact but
  this needs review and possibly better testing)
- discuss the now failing (marked as xfail) tests for nf4
  for t5 when loftq_iter > 1. I don't have an intuition what's
  happening here but the error seems to roughly double per
  iteration.
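
A rough plain-torch illustration of the input-quantization point above (per-tensor absmax here is a simplification of `Linear8bitLt`'s actual row/column-wise scheme): even a perfect correction of the weight residual leaves an error that depends on the activations, which a static initialization cannot anticipate.

import torch

torch.manual_seed(0)
W = torch.randn(64, 64)
x = torch.randn(4, 64)

def absmax_int8(t):
    # Quantize to int8 with a per-tensor absmax scale, then dequantize
    scale = t.abs().max() / 127
    return torch.clamp((t / scale).round(), -127, 127) * scale

y_ref = x @ W.t()
err_weight_only = (x @ absmax_int8(W).t() - y_ref).abs().mean()
err_input_only = (absmax_int8(x) @ W.t() - y_ref).abs().mean()
print(f"error from weight quantization alone: {err_weight_only:.3e}")
print(f"error from input quantization alone:  {err_input_only:.3e}")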
@githubnemo
Collaborator

@jiqing-feng thanks. Let's see what #3088 brings first.

@jiqing-feng
Contributor Author

Hi @githubnemo. With your change, I can pass all LoftQ tests on an NV A100 without any changes to the test files. I will close this PR once your PR is merged.

githubnemo added a commit that referenced this pull request Mar 12, 2026
* Partial fix for LoftQ + int8 quantization

* Adjust MSE computation for padding
* Use multiple inputs instead of only one

This solves some errors; it uncovered some instabilities
that were fixed using a higher rank.


---------

Co-authored-by: nemo <git@ningu.net>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@githubnemo
Collaborator

@jiqing-feng thanks for confirming. PR is merged now so I'll close this.

Thanks again for raising this issue!

@githubnemo githubnemo closed this Mar 12, 2026
kashif pushed a commit to kashif/peft that referenced this pull request Mar 13, 2026
* Partial fix for LoftQ + int8 quantization
@jiqing-feng jiqing-feng deleted the 8bit branch April 8, 2026 01:47