This repository provides a unified framework for Reduced-Precision Inference and Training (SFT and QAT) for large language models.
git clone https://github.com/aiha-lab/rp-framework.git
cd rp-framework && git submodule update --init --recursive && cd ..

Launch the pre-configured Docker container:
docker run -it --rm --gpus all -p 9077:9000 \
--ipc=host \
-v ${PWD}/rp-framework:/rp-framework \
-v ${PWD}/rp-framework/hf_cache:/rp-framework/hf_cache \
-v /raid:/raid \
  superdocker22/rp_framework:1.0 bash  # or use your own Docker image with PyTorch 2.6 + CUDA 12.4 (or a similar version)
cd /rp-framework/rp_inference && bash setup.sh
cd /rp-framework/rp_training && pip install -r requirements.txt
If you want to save your current environment as a Docker image, refer to Saving Your Container as a New Image (docker commit).
For models like LLaMA that require a Hugging Face access token, run the following command to authenticate (for the first model download):
huggingface-cli login
# Add token as git credential? (Y/n) n

The logger is disabled by default.
To enable it, sign up for wandb, log in, and modify the config file.
# Console
pip install wandb
wandb login
# In config file
report_to: wandb
# or, Modifying scripts with --report_to wandb
accelerate launch --config_file configs/zero3.yaml train.py \
--gradient_accumulation_steps 1 \
--per_device_train_batch_size 1 \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--w_format fp4_e2m1 \
--output_dir /rp-framework/model_zoo/llama2-7b-mxfp4-w4a16-alpaca-gpt4-nomask-lora128 \
--save_stats \
--report_to wandb \
  --config configs/sft_lora_alpaca.yaml

Using --save_stats enables dumping a richer set of statistics, including the mean and standard deviation of inputs, weights, and gradients.
If you encounter the following error during dataset generation:
RuntimeError: Dataset scripts are no longer supported, but found piqa.py

please downgrade the datasets library to version 3.3.1, as newer versions have removed support for script-based dataset loading.
pip install datasets==3.3.1

After downloading the dataset, reinstall datasets==4.1.1 to ensure compatibility with SFT training.
This tutorial covers:
- MXFP4 W4A4 PTQ and QAT on Llama-3.2-1B-Instruct using the PIQA dataset
- MXFP4 W4A16 + BF16 LoRA Fine-tuning (QLoRA-like) on Llama-2-7B using Alpaca dataset (Benchmark: MMLU)
huggingface-cli login
# Add token as git credential? (Y/n) n
pip install datasets==3.3.1
cd /rp-framework && mkdir model_zoo

# BF16 PIQA inference on llama3.2-1b-instruct. Expected output: 'acc,none': 0.7393906420021763
cd /rp-framework/rp_inference && bash scripts/run_piqa.sh 0 meta-llama/Llama-3.2-1B-Instruct
# MXFP4 W4A4 PTQ PIQA inference on llama3.2-1b-instruct. Expected output: 'acc,none': 0.6936887921653971
cd /rp-framework/rp_inference && bash scripts/linear_w4a4_piqa.sh 0 meta-llama/Llama-3.2-1B-Instruct

# Dataset generation
cd /rp-framework/rp_training/dataset && python gen_piqa_dataset.py
# Full-precision SFT (takes ~20min with A6000-48GBx4)
# Output model will be saved at /rp-framework/model_zoo/llama3.2-1b-instruct-sft
cd /rp-framework/rp_training && accelerate launch --config_file configs/zero3.yaml train.py --config configs/sft_full.yaml --model_name_or_path meta-llama/Llama-3.2-1B-Instruct --output_dir /rp-framework/model_zoo/llama3.2-1b-instruct-sft
# MXFP4 W4A4 QAT from a cold start using an SFT-trained full-precision model (2-step QAT in gpt-oss recipe: High-precision SFT + QAT)
# Output model will be saved at /rp-framework/model_zoo/llama3.2-1b-instruct-sft-qat-w4a4
cd /rp-framework/rp_training && accelerate launch --config_file configs/zero3.yaml train.py --config configs/sft_qat.yaml --model_name_or_path /rp-framework/model_zoo/llama3.2-1b-instruct-sft --output_dir /rp-framework/model_zoo/llama3.2-1b-instruct-sft-qat-w4a4

# Evaluate Full-precision SFT model. Expected output: 'acc,none': 0.7475516866158868
cd /rp-framework/rp_inference && bash scripts/run_piqa.sh 0 /rp-framework/model_zoo/llama3.2-1b-instruct-sft
# Evaluate Full-precision SFT + MXFP4 W4A4 QAT model accuracy. Expected output: 'acc,none': 0.7263329706202394
cd /rp-framework/rp_inference && bash scripts/linear_w4a4_piqa.sh 0 /rp-framework/model_zoo/llama3.2-1b-instruct-sft-qat-w4a4

# Generate dataset
pip install datasets==4.1.1
cd /rp-framework/rp_training/dataset && python gen_alpaca_dataset.py
# BF16 Weight + BF16 LoRA (taking ~2 hours in A6000x4)
cd /rp-framework/rp_training && accelerate launch --config_file configs/zero3.yaml train.py --gradient_accumulation_steps 1 --per_device_train_batch_size 1 --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir /rp-framework/model_zoo/llama2-7b-alpaca-gpt4-nomask-lora128 --config configs/sft_lora_alpaca.yaml
# MXFP4 W4A16 Weight + BF16 LoRA
cd /rp-framework/rp_training && accelerate launch --config_file configs/zero3.yaml train.py --gradient_accumulation_steps 1 --per_device_train_batch_size 1 --model_name_or_path meta-llama/Llama-2-7b-hf --w_format fp4_e2m1 --output_dir /rp-framework/model_zoo/llama2-7b-mxfp4-w4a16-alpaca-gpt4-nomask-lora128 --config configs/sft_lora_alpaca.yaml

# BF16 MMLU inference on llama2-7b. Expected output: 'mmlu': {'acc,none': 0.4126904999287851, ...
cd /rp-framework/rp_inference && bash scripts/run_mmlu.sh 0 meta-llama/Llama-2-7b-hf
# MXFP4 W4A16 inference. Expected output: 'mmlu': {'acc,none': 0.36355220054123344, ...
cd /rp-framework/rp_inference && bash scripts/linear_w4a16_mmlu.sh 0 meta-llama/Llama-2-7b-hf
# BF16 Weight + BF16 LoRA. Expected output: 'mmlu': {'acc,none': 0.4223045150263495, ...
cd /rp-framework/rp_inference && bash scripts/run_mmlu.sh 0 /rp-framework/model_zoo/llama2-7b-alpaca-gpt4-nomask-lora128
# MXFP4 W4A16 Weight + BF16 LoRA. Expected output: 'mmlu': {'acc,none': 0.40008545791197836, ...
cd /rp-framework/rp_inference && bash scripts/linear_w4a16_mmlu.sh 0 /rp-framework/model_zoo/llama2-7b-mxfp4-w4a16-alpaca-gpt4-nomask-lora128

### scripts/linear_w4a4_piqa.sh
# Task configuration
tasks=piqa # or winogrande, hellaswag, mmlu, boolq, ...
num_fewshot=none
eval_ppl=false # Set true to evaluate perplexity on Wikitext-2 instead of CSQA
# Linear layer precision (MXFP4)
w_elem_format_linear=fp4_e2m1 # set to "none" for BF16 baseline
a_elem_format_linear=fp4_e2m1 # activation element format (matches weight format for W4A4)
block_size_linear=32 # number of elements sharing one scale in an MX block
scale_bits_linear=8 # bit width of each shared block scale
w_scale_mode=0 # 0: PoT (Floor), 3: PoT (Round), 152: E5M2
a_scale_mode=0

cd /rp-framework/rp_inference
# (1) Baseline inference (no quantization)
bash scripts/run_piqa.sh 0 meta-llama/Llama-3.2-1B-Instruct
# Expected accuracy: ~73.94 on PIQA
# (2) W4A4 inference (PTQ)
bash scripts/linear_w4a4_piqa.sh 0 meta-llama/Llama-3.2-1B-Instruct
# Expected accuracy: ~69.37 on PIQA
# (3) Optional: Multi-GPU inference for larger models
bash scripts/linear_w4a4_piqa.sh 0,1,2,3 meta-llama/Llama-3.2-1B-Instruct

If the dataset already exists, you can skip this step.
cd /rp-framework/rp_training/dataset
python gen_piqa_dataset.py
# The folder "piqa-train-llama3.2" will be created in the current directory

# rp_training/dataset/gen_piqa_dataset.py
# Set up the tokenizer used to encode PIQA examples for Llama-3.2.
from datasets import load_dataset, DatasetDict
import transformers
from transformers import AddedToken  # NOTE(review): AddedToken appears unused in this snippet — confirm it is needed
tokenizer = transformers.AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
def convert_to_language_modeling(example):
    """Map one raw PIQA example to a causal-LM sample.

    Returns a dict with 'input_ids' (goal + correct solution, tokenized)
    and 'completion_mask', where prompt positions are 0 and solution
    positions are 1 so loss is applied only on the solution part.
    """
    label = example["label"]
    goal = example["goal"]
    # For PIQA: choose between goal+sol1 vs goal+sol2
    solution = example["sol1"] if label == 0 else example["sol2"]
    input_ids = tokenizer.encode(f'{goal} {solution}')
    completion_mask = [1] * len(input_ids)
    # Zero out the prompt's token positions; the prompt length is taken
    # from a separate encoding of the goal text alone.
    prompt_token_count = len(tokenizer.encode(goal))
    for position in range(prompt_token_count):
        completion_mask[position] = 0
    return {'input_ids': input_ids, 'completion_mask': completion_mask}
# Load the raw PIQA training split and convert every example into
# {'input_ids', 'completion_mask'} for causal-LM fine-tuning.
raw_train = load_dataset("nthngdy/piqa", split="train")
converted_train = raw_train.map(convert_to_language_modeling)
DatasetDict({"train": converted_train}).save_to_disk("piqa-train-llama3.2")

cd /rp-framework/rp_training
accelerate launch --config_file configs/zero3.yaml train.py --config configs/sft_full.yaml --model_name_or_path meta-llama/Llama-3.2-1B-Instruct --output_dir /rp-framework/model_zoo/llama3.2-1b-instruct-sft

model_name_or_path: meta-llama/Llama-3.2-1B-Instruct
attn_implementation: eager
torch_dtype: bfloat16
# Dataset
dataset_name: dataset/piqa-train-llama3.2
dataset_num_proc: 12
# Hyperparameters
learning_rate: 2.0e-5
gradient_checkpointing: true
num_train_epochs: 1.0
logging_steps: 1
per_device_train_batch_size: 2
per_device_eval_batch_size: 2
gradient_accumulation_steps: 2
max_length: 2048
warmup_ratio: 0.03
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
output_dir:
report_to: wandb

cd /rp-framework/rp_inference
bash scripts/linear_w4a4_piqa.sh 0 /rp-framework/model_zoo/llama3.2-1b-instruct-sft
# Expected accuracy: ~75.24 (improved after SFT)

cd /rp-framework/rp_training
accelerate launch --config_file configs/zero3.yaml train.py --config configs/sft_qat.yaml --model_name_or_path /rp-framework/model_zoo/llama3.2-1b-instruct-sft --output_dir /rp-framework/model_zoo/llama3.2-1b-instruct-sft-qat-w4a4

model_name_or_path: /rp-framework/model_zoo/llama3.2-1b-instruct-sft
# Bit precision (Linear layers)
w_format: fp4_e2m1
a_format: fp4_e2m1
g_format: null

cd /rp-framework/rp_inference
bash scripts/linear_w4a4_piqa.sh 0 /rp-framework/model_zoo/llama3.2-1b-instruct-sft-qat-w4a4
# Expected accuracy: ~73.61 (close to BF16 baseline)

| Stage | Model Path | Expected Accuracy | Description |
|---|---|---|---|
| Baseline (BF16) | meta-llama/Llama-3.2-1B-Instruct | 73.94 | No quantization |
| W4A4 PTQ | meta-llama/Llama-3.2-1B-Instruct | 69.37 | Post-training quantization |
| SFT | ./llama3.2-1b-instruct-sft | 75.24 | Fine-tuned with PIQA |
| QAT (W4A4) | ./llama3.2-1b-instruct-sft-qat-w4a4 | 73.61 | Quantization-aware training |