Adds Quantization Error Propagation (QEP) Algorithm #213
copybara-service[bot] wants to merge 1 commit into main
## Summary
Add QEP (Quantization Error Propagation) support to Qwix, extending GPTQ to account for quantization noise in input activations from previous layers. QEP applies a Hessian-based weight correction before GPTQ quantization, reducing output error when both weights and activations are quantized.
## Algorithm Overview
### The Problem
Standard GPTQ minimizes `||W @ X - W_q @ X||^2`: the output error when only the weights are quantized and the inputs remain in float. In a fully quantized model, inputs are also quantized by prior layers, so the actual inference error is `||W @ X_float - W_q @ X_q||^2`. GPTQ ignores this input quantization noise and leaves some accuracy on the table.
### QEP Solution
QEP decomposes the quantized-input error into two components:
```
||W @ X_float - W_q @ X_q||^2
  = ||(W - W_q) @ X_q  +  W @ (X_float - X_q)||^2
        (weight error)        (input error)
```
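The identity is exact, which a quick numerical check confirms (toy round-to-grid quantizers stand in for real weight/activation quantization):
```python
import jax
import jax.numpy as jnp

w = jax.random.normal(jax.random.key(0), (4, 8))
x_f = jax.random.normal(jax.random.key(1), (8, 16))
w_q = jnp.round(w * 4) / 4    # toy weight quantizer
x_q = jnp.round(x_f * 4) / 4  # toy activation quantizer

lhs = w @ x_f - w_q @ x_q
rhs = (w - w_q) @ x_q + w @ (x_f - x_q)
assert jnp.allclose(lhs, rhs, atol=1e-5)
```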
To compensate, QEP collects two statistics during calibration:
- **Hessian** `H = X_q @ X_q^T` — input covariance from quantized activations
- **Hessian delta** `H_delta = (X_float - X_q) @ X_q^T` — cross-correlation between input quantization error and quantized activations
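Given LHS activations from the two calibration passes, both statistics are plain matrix products. A minimal per-batch sketch, assuming activations reshaped to `(features, tokens)`; the PR accumulates these across batches with `SimpleMovingAverage` rather than this one-shot form:
```python
import jax.numpy as jnp

def batch_stats(x_float: jnp.ndarray, x_q: jnp.ndarray):
  # x_float, x_q: (features, tokens) activations from the float and
  # quantized calibration passes of the same batch.
  h = x_q @ x_q.T                    # Hessian: quantized-input covariance
  h_delta = (x_float - x_q) @ x_q.T  # cross-correlation with the input error
  return h, h_delta
```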
QEP then applies a weight correction before quantization:
```
W_corrected = W + alpha * (W @ H_delta @ inv(H))
```
where `alpha` (correction factor, default 0.5) controls correction strength. At `alpha = 1` this is the least-squares solution of `min ||W_corrected @ X_q - W @ X_float||^2`; smaller values trade correction strength for robustness. The correction adjusts weights so that `W_corrected @ X_q ~= W @ X_float`, partially canceling the effect of quantized inputs. The corrected weights are then quantized with GPTQ for optimal rounding.
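For intuition, the correction can be computed with a damped linear solve instead of an explicit inverse. This is a minimal sketch of the idea, not the PR's actual `qep_core.weight_correct`; the dampening scheme (a fraction of the mean diagonal of `H`, as in common GPTQ implementations) is an assumption:
```python
import jax
import jax.numpy as jnp

def weight_correct(w, h, h_delta, correction_factor=0.5, dampening_factor=0.01):
  """Applies W + alpha * W @ H_delta @ inv(H) with a dampened H."""
  damp = dampening_factor * jnp.mean(jnp.diag(h))
  h_damped = h + damp * jnp.eye(h.shape[0], dtype=h.dtype)
  # Solve H @ Z^T = (W @ H_delta)^T for Z = W @ H_delta @ inv(H);
  # H is symmetric positive definite after dampening.
  z = jax.scipy.linalg.solve(h_damped, (w @ h_delta).T, assume_a='pos').T
  return w + correction_factor * z
```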
## Implementation
QEP plugs into Qwix's existing interception-based quantization framework:
1. **Configuration**: `QepRule` extends `GptqRule` with two additional parameters:
- `correction_factor` (default 0.5): controls weight correction strength
- `dampening_factor` (default 0.01): dampening for the QEP Hessian inversion
2. **Calibration** (`QepCalibrationProvider`): Extends `StatsCalibrationProvider` but overrides `dot_general` directly (rather than using `compute_stats`) to implement a two-pass protocol per batch via `calibrate_batch()`:
- **Float pass**: Forward pass with original float params; caches float-precision LHS activations in a Python-side dict, keyed by `module_path/var_name`
- **Quantized pass**: Forward pass with dequantized PTQ params; computes `H` and `H_delta` at each intercepted `dot_general` using cached float LHS and current quantized LHS; accumulates via `SimpleMovingAverage`
- Must run eagerly (outside `jax.jit`) because the LHS cache is Python-side state; a minimal sketch of this caching protocol appears after this list
3. **Weight quantization** (`qep.quantize_params`): For each QEP-matched weight:
- Normalize weight to `(ra, ca)` format via `gptq_core.normalize_weight`
- Apply weight correction: `qep_core.weight_correct(W, H, H_delta, correction_factor, dampening_factor)`
- Quantize with GPTQ: `gptq_core.quantize_weight(W_corrected, H, how)`
- Fall back to PTQ for non-QEP-matched params
4. **Inference**: Uses the same `PtqProvider` as standard GPTQ/PTQ.
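To make the two-pass calibration protocol concrete, here is a self-contained sketch of the caching idea referenced in step 2. The class and method names are illustrative; Qwix's actual `QepCalibrationProvider` does this by intercepting `dot_general`:
```python
class TwoPassLhsCache:
  """Pairs float-pass LHS activations with quantized-pass ones."""

  def __init__(self):
    self._float_lhs = {}  # Python-side state: why calibration runs eagerly.

  def on_float_dot_general(self, key, lhs):
    # Pass 1 (original float params): cache the float LHS for this op.
    self._float_lhs[key] = lhs

  def on_quant_dot_general(self, key, lhs_q):
    # Pass 2 (dequantized PTQ params): pair with the cached float LHS
    # and form this batch's contribution to H and H_delta.
    lhs_f = self._float_lhs.pop(key)
    x_f = lhs_f.reshape(-1, lhs_f.shape[-1]).T  # (features, tokens)
    x_q = lhs_q.reshape(-1, lhs_q.shape[-1]).T
    return x_q @ x_q.T, (x_f - x_q) @ x_q.T     # H, H_delta
```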
### Usage
```python
import jax
import jax.numpy as jnp
from qwix._src import model as qwix_model
from qwix._src import qarray  # import path assumed; provides qarray.dequantize
from qwix._src.providers import ptq
from qwix.contrib import qep

# `model`, `variables` (its float variables), `x` (a sample input), and
# `calibration_data` are assumed to exist.
# 1. Define QEP rules.
rules = [qep.QepRule(module_path='Dense_0', weight_qtype=jnp.int8)]
# 2. Create PTQ-dequantized params (needed for the quantized calibration pass).
ptq_provider = ptq.PtqProvider(rules)
ptq_model = qwix_model.quantize_model(model, ptq_provider)
abs_variables = jax.eval_shape(ptq_model.init, jax.random.key(0), x)
ptq_params = ptq.quantize_params(variables['params'], abs_variables['params'])
deq_params = jax.tree.map(
    lambda p: qarray.dequantize(p.array) if isinstance(p, ptq.WithAux) else p,
    ptq_params,
    is_leaf=lambda p: isinstance(p, ptq.WithAux),
)
# 3. QEP calibration (two-pass per batch).
qep_provider = qep.QepCalibrationProvider(rules)
cal_model = qwix_model.quantize_model(model, qep_provider)
cal_variables = dict(variables)
for batch in calibration_data:
  new_vars = qep_provider.calibrate_batch(
      cal_model,
      cal_variables,           # float variables (original weights)
      {'params': deq_params},  # quantized variables (dequantized PTQ weights)
      batch,
  )
  cal_variables.update(new_vars)
# 4. Quantize weights with QEP (weight correction + GPTQ).
qep_params = qep.quantize_params(
    variables['params'],
    abs_variables['params'],
    cal_variables['quant_stats'],
    correction_factor=0.5,
    dampening_factor=0.01,
)
# 5. Inference.
output = ptq_model.apply({'params': qep_params}, x)
```
PiperOrigin-RevId: 870737714