# Adds Quantization Error Propagation (QEP) Algorithm #213

**Status:** Open. copybara-service[bot] wants to merge 1 commit into `main` from `test_870737714`.

## Summary

Add QEP (Quantization Error Propagation) support to Qwix, extending GPTQ to account for quantization noise in input activations from previous layers. QEP applies a Hessian-based weight correction before GPTQ quantization, reducing output error when both weights and activations are quantized.

## Algorithm Overview

### The Problem

Standard GPTQ minimizes `||W @ X - W_q @ X||^2`, the error when only the weights are quantized and the inputs remain in float. In a real quantized model, the inputs are themselves quantized by prior layers, so the actual inference error is `||W @ X_float - W_q @ X_q||^2`. GPTQ ignores this input quantization noise and leaves some recoverable accuracy on the table.

### QEP Solution

QEP decomposes the quantized-input error into two components:

```
||W @ X_float - W_q @ X_q||^2
= ||(W - W_q) @ X_q  +  W @ (X_float - X_q)||^2
    (weight error)           (input error)
```
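
The decomposition above is an exact algebraic identity, which a few lines of NumPy can confirm (the round-to-quarter quantizer below is a toy stand-in, not Qwix's quantizer):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))
X_float = rng.normal(size=(8, 16))

# Toy quantizers: round weights and activations to multiples of 0.25.
W_q = np.round(W * 4) / 4
X_q = np.round(X_float * 4) / 4

total_err = W @ X_float - W_q @ X_q
weight_err = (W - W_q) @ X_q       # weight-error term
input_err = W @ (X_float - X_q)    # input-error term
assert np.allclose(total_err, weight_err + input_err)

# The float-input GPTQ objective differs from the real quantized-input error:
gptq_obj = np.linalg.norm(W @ X_float - W_q @ X_float)
real_err = np.linalg.norm(total_err)
```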

To compensate, QEP collects two statistics during calibration:

- **Hessian** `H = X_q @ X_q^T` — input covariance from quantized activations
- **Hessian delta** `H_delta = (X_float - X_q) @ X_q^T` — cross-correlation between input quantization error and quantized activations

Then applies a weight correction before quantization:

```
W_corrected = W + alpha * (W @ H_delta @ inv(H))
```

where `alpha` (correction factor, default 0.5) controls correction strength. This adjusts weights so that `W_corrected @ X_q ~= W @ X_float`, partially canceling the effect of quantized inputs. The corrected weights are then quantized using GPTQ for optimal rounding.
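
A direct reading of this formula can be sketched in NumPy. The dampening scheme here (adding `dampening_factor * mean(diag(H))` to the diagonal before inversion) is a common GPTQ-style choice and an assumption on my part; `qep_core.weight_correct` may differ in detail:

```python
import numpy as np

def weight_correct(w, h, h_delta, correction_factor=0.5, dampening_factor=0.01):
    """W_corrected = W + alpha * (W @ H_delta @ inv(H)), with dampened H."""
    # Dampen H for numerical stability before inverting (scheme assumed).
    damp = dampening_factor * np.mean(np.diag(h))
    h_damped = h + damp * np.eye(h.shape[0])
    # Solve X @ H = W @ H_delta rather than forming inv(H) explicitly.
    correction = np.linalg.solve(h_damped.T, (w @ h_delta).T).T
    return w + correction_factor * correction
```

With `H_delta = 0` (no input quantization error) the correction vanishes and QEP reduces to plain GPTQ on the original weights.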

## Implementation

QEP plugs into Qwix's existing interception-based quantization framework:

1. **Configuration**: `QepRule` extends `GptqRule` with two additional parameters:
   - `correction_factor` (default 0.5): controls weight correction strength
   - `dampening_factor` (default 0.01): dampening for the QEP Hessian inversion

2. **Calibration** (`QepCalibrationProvider`): Extends `StatsCalibrationProvider` but overrides `dot_general` directly (rather than using `compute_stats`) to implement a two-pass protocol per batch via `calibrate_batch()`:
   - **Float pass**: Forward pass with original float params; caches float-precision LHS activations in a Python-side dict, keyed by `module_path/var_name`
   - **Quantized pass**: Forward pass with dequantized PTQ params; computes `H` and `H_delta` at each intercepted `dot_general` using cached float LHS and current quantized LHS; accumulates via `SimpleMovingAverage`
   - Must run eagerly (outside `jax.jit`) because the LHS cache is Python-side state

3. **Weight quantization** (`qep.quantize_params`): For each QEP-matched weight:
   - Normalize weight to `(ra, ca)` format via `gptq_core.normalize_weight`
   - Apply weight correction: `qep_core.weight_correct(W, H, H_delta, correction_factor, dampening_factor)`
   - Quantize with GPTQ: `gptq_core.quantize_weight(W_corrected, H, how)`
   - Fall back to PTQ for non-QEP-matched params

4. **Inference**: Uses the same `PtqProvider` as standard GPTQ/PTQ.
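
The two-pass protocol in step 2 can be sketched as a single intercepted hook; the hook signature and cache layout below are illustrative, not Qwix's actual internals. The Python-side dict is exactly why calibration must run eagerly: under `jax.jit`, the cached values would be tracers rather than concrete arrays.

```python
import numpy as np

lhs_cache = {}   # Python-side state: 'module_path/var_name' -> float LHS
quant_stats = {} # 'module_path/var_name' -> {'h': ..., 'h_delta': ...}

def on_dot_general(path, lhs, float_pass):
    """Called at each intercepted dot_general; lhs is (samples, features)."""
    if float_pass:
        lhs_cache[path] = lhs          # float pass: cache float-precision LHS
        return
    x_float = lhs_cache.pop(path)      # quantized pass: pair with cached LHS
    x_quant = lhs
    n = lhs.shape[-1]
    site = quant_stats.setdefault(
        path, {'h': np.zeros((n, n)), 'h_delta': np.zeros((n, n))}
    )
    site['h'] += x_quant.T @ x_quant                    # H = X_q @ X_q^T
    site['h_delta'] += (x_float - x_quant).T @ x_quant  # (X_f - X_q) @ X_q^T
```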

### Usage

```python
import jax
import jax.numpy as jnp
from qwix._src import model as qwix_model
from qwix._src import qarray  # module path assumed; provides dequantize()
from qwix._src.providers import ptq
from qwix.contrib import qep

# Assumes `model`, `x`, `variables`, and `calibration_data` are already defined.

# 1. Define QEP rules.
rules = [qep.QepRule(module_path='Dense_0', weight_qtype=jnp.int8)]

# 2. Create PTQ-dequantized params (needed for the quantized calibration pass).
ptq_provider = ptq.PtqProvider(rules)
ptq_model = qwix_model.quantize_model(model, ptq_provider)
abs_variables = jax.eval_shape(ptq_model.init, jax.random.key(0), x)
ptq_params = ptq.quantize_params(variables['params'], abs_variables['params'])
deq_params = jax.tree.map(
    lambda p: qarray.dequantize(p.array) if isinstance(p, ptq.WithAux) else p,
    ptq_params,
    is_leaf=lambda p: isinstance(p, ptq.WithAux),
)

# 3. QEP calibration (two-pass per batch).
qep_provider = qep.QepCalibrationProvider(rules)
cal_model = qwix_model.quantize_model(model, qep_provider)
cal_variables = dict(variables)
for batch in calibration_data:
    new_vars = qep_provider.calibrate_batch(
        cal_model,
        cal_variables,              # float variables (original weights)
        {'params': deq_params},     # quantized variables (dequantized PTQ weights)
        batch,
    )
    cal_variables.update(new_vars)

# 4. Quantize weights with QEP (weight correction + GPTQ).
qep_params = qep.quantize_params(
    variables['params'],
    abs_variables['params'],
    cal_variables['quant_stats'],
    correction_factor=0.5,
    dampening_factor=0.01,
)

# 5. Inference.
output = ptq_model.apply({'params': qep_params}, x)
```
PiperOrigin-RevId: 870737714