diff --git a/README.md b/README.md index 210b5f1..ff92a7b 100644 --- a/README.md +++ b/README.md @@ -5,53 +5,5 @@ - Model Behavior-Level Simulation - Hardware-Performance Simulation -**🔖 For tutorials and examples, please refer to [this site](https://aicrosssim.github.io/NewComputeBench/)**. +**🔖 For milestones, tutorials and examples, please refer to [this site](https://aicrosssim.github.io/NewComputeBench/)**. -## Model Training - -### LLMs - -We adopt Llama-3 architecture and aim to support the following features: - -- Pretraining -- Generation (inference) -- Parameter-efficient fine-tuning -- `🚧 TODO` `🌐 LowPriority`: Supervised-fine-tuning -- Evaluation - -#### PreTraining - -The LLM pretraining is built on top of [torchtitan](https://github.com/pytorch/torchtitan). - -- Model architecture: [`Llama3`](/src/torchtitan/models/llama/model.py) -- Model configs: [`60M`, `200M`, `400M`, `1.1B`](src/aixsim_models/llm/model_flavors.py) -- Datasets: [`HuggingFaceFW/fineweb`](/src/aixsim_models/llm/pretrain_data.py) -- HuggingFace checkpoints: [AICrossSim](https://huggingface.co/AICrossSim) - -#### Generation - -We recommend using the HuggingFace Transformers library for generation tasks. -We provide a script to convert the torchtitan checkpoint to a HuggingFace checkpoint (See [this file](/experiments/llm-digital/pretrain/README.md)). - - -#### Parameter-Efficient Fine-tuning -- For models larger than 1.1B, we fine-tune pretrained checkpoints. 
- - LoRA fine-tuning data - - LoRA fine-tuning scripts - -## Model Behavior Simulation - -- [Random bitflip](/experiments/llm-bitflip/) - - Post-training bitflip transform - - Bitflip-aware pretraining -- Optical compute - - [Roberta on GLUE](/experiments/roberta-optical-transformer/) - - CLM `🚧 WIP` - -- Spiking neural networks `🚧 TODO` -- In-memory compute `🚧 TODO` - -## Hardware-Performance Simulation - -`🚧 TODO` diff --git a/docs/02-model-behaviour-level-simulation/clm-bitflip-lora-finetune.md b/docs/02-model-behaviour-level-simulation/clm-bitflip-lora-finetune.md new file mode 100644 index 0000000..d2a7592 --- /dev/null +++ b/docs/02-model-behaviour-level-simulation/clm-bitflip-lora-finetune.md @@ -0,0 +1,222 @@ +# Bitflip-Aware LoRA Fine-Tuning + +This tutorial walks through how to run bitflip-aware LoRA fine-tuning on a pretrained LLM (e.g., `unsloth/Llama-3.1-8B`) using our custom training script. + +## Overview + +Bitflip-aware LoRA fine-tuning combines two ideas: + +1. **Random Bitflip Simulation** — During the forward pass, random bit flips are injected into both activations and weights of every linear layer (except `lm_head`). This emulates hardware-level bit errors that occur in approximate or unreliable compute substrates. +2. **Low-Rank Adaptation (LoRA)** — Instead of fine-tuning all parameters, we attach small low-rank matrices (`lora_A`, `lora_B`) to each linear layer and only train those. The original pretrained weights are frozen. + +By fine-tuning with bitflip noise injected during training, the LoRA adapters learn to compensate for hardware-induced errors, making the model more resilient at inference time. + +### How It Works + +Each `nn.Linear` layer in the model is replaced by a [`BitFlipLinearLora`](https://github.com/AICrossSim/NewComputeBench/blob/master/src/aixsim_models/bitflip/fine_tune/bitflip_lora.py) layer. 
The forward pass of `BitFlipLinearLora` performs the following: + +``` +Y = bitflip(X) @ bitflip(W + B @ A * scaling)^T + bias +``` + +where: + +- `X` is the input activation (bitflip noise is optional, controlled by the `x_p_*` probabilities). +- `W` is the frozen pretrained weight. +- `A` (`lora_A`) and `B` (`lora_B`) are the trainable low-rank matrices. +- `scaling = lora_alpha / r` controls the magnitude of the LoRA update. +- `bitflip(·)` applies random bit flips to the sign-exponent and mantissa bits of the FP32 representation, controlled by per-component probabilities. + +The model transformation is handled by the [`transform_llama`](https://github.com/AICrossSim/NewComputeBench/blob/master/src/aixsim_models/bitflip/fine_tune/bitflip_llama.py) function, which iterates over all `nn.Linear` modules in the model (excluding `lm_head`) and replaces them with `BitFlipLinearLora`. + +### Entry Points + +| File | Description | +|------|-------------| +| [`experiments/llm-bitflip/lora_finetune/run_clm_no_trainer.py`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/run_clm_no_trainer.py) | Main training script (HuggingFace Accelerate-based, no Trainer) | +| [`experiments/llm-bitflip/lora_finetune/fine-tune-bitflip-clm.sh`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/fine-tune-bitflip-clm.sh) | Shell wrapper that computes training steps and launches the run | +| [`experiments/llm-bitflip/lora_finetune/transform_cfg.toml`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/transform_cfg.toml) | Bitflip + LoRA configuration file | + +## Step-by-Step Guide + +!!! info "Environment Setup" + + If you have not set up the environment, please follow the guidelines in [Environment Setup](../env-setup.md). + +### 1. Configure the Bitflip & LoRA Transform + +The transform configuration is defined in a TOML file. 
Here is the default configuration at [`experiments/llm-bitflip/lora_finetune/transform_cfg.toml`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/transform_cfg.toml): + +```toml +use_lora = true + +[fc] + w_p_exp = 1.52587890625e-05 + w_p_frac = 1.52587890625e-05 + w_zero_out_t = 1.25 + x_p_exp = 1.52587890625e-05 + x_p_frac = 1.52587890625e-05 + x_zero_out_t = 30.0 + +[lora] + r = 32 + lora_alpha = 32 +``` + +**Configuration parameters:** + +| Section | Parameter | Description | +|---------|-----------|-------------| +| (top-level) | `use_lora` | Enable LoRA adaptation (`true`/`false`). When `false`, all parameters are trained. | +| `[fc]` | `w_p_exp` | Bitflip probability for the sign-exponent bits of the **weight**. | +| `[fc]` | `w_p_frac` | Bitflip probability for the mantissa bits of the **weight**. | +| `[fc]` | `w_zero_out_t` | Threshold for zeroing out weight outliers / NaN values. | +| `[fc]` | `x_p_exp` | Bitflip probability for the sign-exponent bits of the **activation**. | +| `[fc]` | `x_p_frac` | Bitflip probability for the mantissa bits of the **activation**. | +| `[fc]` | `x_zero_out_t` | Threshold for zeroing out activation outliers / NaN values. | +| `[lora]` | `r` | LoRA rank. | +| `[lora]` | `lora_alpha` | LoRA scaling factor (effective scaling = `lora_alpha / r`). | + +!!! note "Bitflip probability" + The bitflip probability must be a power of 0.5 (e.g., `0.5^16 ≈ 1.526e-05`). The kernel automatically snaps to the nearest valid value. Due to limitations of the Philox PRNG, the minimum supported probability is `0.5^24 ≈ 5.96e-08`. See the [mase-triton docs](../02-model-behaviour-level-simulation/mase-triton.md) for more details. + +### 2. 
Understand the Training Budget + +The shell script [`fine-tune-bitflip-clm.sh`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/fine-tune-bitflip-clm.sh) automatically calculates the number of training steps based on a budget of **1% of the model's parameter count in tokens**. For `unsloth/Llama-3.1-8B` (8B parameters): + +``` +fine-tune tokens = 8,000,000,000 / 100 = 80,000,000 tokens +tokens per step = num_gpus × per_device_batch_size × block_size +max_train_steps = fine-tune tokens / tokens per step +``` + +For example, with 8 GPUs, batch size 1, and block size 2048: + +``` +tokens per step = 8 × 1 × 2048 = 16,384 +max_train_steps = 80,000,000 / 16,384 ≈ 4,883 steps +``` + +### 3. Launch the Fine-Tuning + +```bash +cd experiments/llm-bitflip/lora_finetune +``` + +The script accepts positional arguments to override defaults: + +```bash +./fine-tune-bitflip-clm.sh [num_processes] [model_name_or_path] [per_device_train_batch_size] [learning_rate] [weight_decay] [gradient_accumulation_steps] [block_size] +``` + +**Example: Fine-tune Llama-3.1-8B on 8 GPUs with default settings** + +```bash +./fine-tune-bitflip-clm.sh 8 unsloth/Llama-3.1-8B 1 1e-5 0.01 2 2048 +``` + +This is equivalent to running the underlying command directly: + +```bash +uv run accelerate launch --num_processes=8 \ + run_clm_no_trainer.py \ + --model_name_or_path unsloth/Llama-3.1-8B \ + --dataset_name Cheng98/fineweb-edu-1.25B \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --learning_rate 1e-5 \ + --weight_decay 0.01 \ + --num_train_epochs 1 \ + --gradient_accumulation_steps 2 \ + --lr_scheduler_type linear \ + --output_dir ./output/Llama-3.1-8B-bitflip-lora \ + --preprocessing_num_workers 32 \ + --trust_remote_code \ + --with_tracking \ + --report_to wandb \ + --transform_cfg ./transform_cfg.toml \ + --block_size 2048 \ + --log_train_loss_steps 50 \ + --max_train_steps 4883 \ + --wandb_tags 
unsloth/Llama-3.1-8B,lr1e-5,steps4883 +``` + +**Key arguments:** + +| Argument | Description | +|----------|-------------| +| `--model_name_or_path` | HuggingFace model identifier or local path. | +| `--dataset_name` | Training dataset. We use a 1.25B-token subset of [FineWeb-Edu](https://huggingface.co/datasets/Cheng98/fineweb-edu-1.25B). | +| `--transform_cfg` | Path to the TOML config for bitflip + LoRA. | +| `--block_size` | Context length for training samples. | +| `--log_train_loss_steps` | Log training loss to W&B every N steps. | +| `--max_train_steps` | Total number of optimizer steps (auto-calculated by the shell script). | + +!!! tip "Adjusting GPU count" + The first argument to `fine-tune-bitflip-clm.sh` controls `--num_processes` for `accelerate launch`. The script automatically recalculates `max_train_steps` to maintain the same total token budget regardless of the number of GPUs. + +### 4. Monitor Training + +If you have W&B set up (`wandb login`), training loss and validation perplexity are logged automatically. Runs are logged to the W&B project `Bitflip-CLM-Fine-tune`. + +- **Training loss** is logged every 50 steps (configurable via `--log_train_loss_steps`). +- **Validation perplexity** is evaluated at the end of each epoch on the first 64 batches of the validation set. + +### 5. Output + +After training completes, the fine-tuned model (with LoRA weights merged into the base model) and tokenizer are saved to the output directory: + +``` +./output/Llama-3.1-8B-bitflip-lora/ +├── config.json +├── model.safetensors +├── tokenizer.json +├── tokenizer_config.json +└── all_results.json # Final perplexity +``` + +## Results + +!!! warning "Results Pending" + The following results are placeholders and will be updated once experiments complete. 
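The total step count reported below follows from the budget rule in step 2. As a quick sanity check, the arithmetic can be reproduced directly; the `max_train_steps` helper below is our own sketch of the shell script's calculation, not part of the repository:

```python
def max_train_steps(n_params: int, num_gpus: int, batch_size: int, block_size: int) -> int:
    """Token budget = 1% of the parameter count, ceiling-divided by tokens per step."""
    fine_tune_tokens = n_params // 100                    # 1% of parameters, in tokens
    tokens_per_step = num_gpus * batch_size * block_size  # tokens consumed per step
    return -(-fine_tune_tokens // tokens_per_step)        # ceiling division

# Llama-3.1-8B on 8 GPUs, per-device batch size 1, block size 2048
print(max_train_steps(8_000_000_000, 8, 1, 2048))  # 4883
```

Note this mirrors the `(a + b - 1) / b` ceiling division used in `fine-tune-bitflip-clm.sh`, so halving the GPU count roughly doubles the step count while keeping the token budget fixed.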
+ +### Training Curves + +![Bitflip LoRA Fine-Tuning Curves](../images/bitflip/7b-lora-trainloss.png){ width=720px } + + +| Metric | Value | +|--------|-------| +| Final Training Loss | *2.50* | +| Final Validation Perplexity | *11.01* | +| Total Training Steps | *4883* | + +### Comparison with Baselines + +We evaluate the model under three conditions: + +| Bitflipped | Fine-tuned | Bitflip Config | Fine-tune Config | Train PPL | +|------------|------------|----------------|------------------|-----------| +| ✘ | ✘ | N/A | N/A | *7.91* | +| ✔ | ✘ | `w/x_p_exp=1.53e-5, w/x_p_frac=1.53e-5` | N/A | *1008.95* | +| ✔ | ✔ | `w/x_p_exp=1.53e-5, w/x_p_frac=1.53e-5` | LoRA rank=32 | *11.01* | + +From the table above, we can see that *LoRA fine-tuning effectively mitigates the impact of bitflip noise, reducing perplexity from 1008.95 to 11.01* for an 8B model. + +With more trainable parameters (e.g., a larger LoRA rank, or full fine-tuning), we expect the model to compensate for the noise even better. + +### Resources + +| Resource | Link | +|----------|------| +| W&B Logs | *https://wandb.ai/cz98/Bitflip-CLM-Fine-tune* | +| Training Config | [`transform_cfg.toml`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/transform_cfg.toml) | + +## Appendix: Evaluation Scripts + +The comparison table above was generated with two evaluation-only wrappers that reuse `run_clm_no_trainer.py` but bypass any optimizer steps. Both scripts share the signature `./script.sh [num_processes] [model_name_or_path] [per_device_batch_size] [block_size] [eval_max_steps]` so you can sweep models or batch sizes without editing Python code. 
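For reference, the transform that both the fine-tuning run and the bitflipped evaluations apply is the `BitFlipLinearLora` forward from the overview. The sketch below is a simplified, hypothetical re-implementation in plain PyTorch, not the repository's layer: it flips only mantissa bits (the real mase-triton kernel also handles sign-exponent bits, probability snapping, and outlier zeroing), omits the bias, and uses a straight-through trick for gradients, which may differ from the actual autograd handling:

```python
import torch

def bitflip_frac(x: torch.Tensor, p_frac: float, seed: int = 0) -> torch.Tensor:
    """Flip each of the 23 FP32 mantissa bits independently with probability p_frac.
    Simplified stand-in for the mase-triton bitflip kernel."""
    gen = torch.Generator().manual_seed(seed)
    bits = x.contiguous().view(torch.int32)  # reinterpret the FP32 payload as raw bits
    for b in range(23):  # mantissa bit positions 0..22
        mask = torch.rand(x.shape, generator=gen) < p_frac
        bits = torch.where(mask, bits ^ (1 << b), bits)
    return bits.view(torch.float32)

class BitFlipLinearLoraSketch(torch.nn.Module):
    """Hypothetical sketch of Y = bitflip(X) @ bitflip(W + B @ A * scaling)^T (bias omitted)."""

    def __init__(self, in_features: int, out_features: int, r: int = 32,
                 lora_alpha: int = 32, p_frac: float = 1.52587890625e-05):
        super().__init__()
        # Frozen pretrained weight; only the LoRA factors are trainable.
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features) * 0.02,
                                         requires_grad=False)
        self.lora_A = torch.nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = torch.nn.Parameter(torch.zeros(out_features, r))
        self.scaling = lora_alpha / r
        self.p_frac = p_frac

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight + self.lora_B @ self.lora_A * self.scaling
        # Straight-through noise injection so gradients still reach lora_A / lora_B.
        w_noisy = w + (bitflip_frac(w.detach(), self.p_frac) - w.detach())
        x_noisy = x + (bitflip_frac(x.detach(), self.p_frac) - x.detach())
        return x_noisy @ w_noisy.T
```

Setting the probabilities to zero recovers a plain frozen linear layer with trainable LoRA factors, which is what the no-bitflip baseline evaluation effectively measures.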
+ +| Script | Purpose | Notes | +|--------|---------|-------| +| [`experiments/llm-bitflip/lora_finetune/eval-bitflip-no-finetune.sh`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/eval-bitflip-no-finetune.sh) | Measures perplexity when random bitflips are injected during inference. | This is the bitflipped (✔), not fine-tuned (✘) entry. | +| [`experiments/llm-bitflip/lora_finetune/eval-no-biflip-no-finetune.sh`](https://github.com/AICrossSim/NewComputeBench/blob/master/experiments/llm-bitflip/lora_finetune/eval-no-biflip-no-finetune.sh) | Serves as the clean baseline (no injected bitflips, no finetuning) so we can isolate the effect of noise. | This is the bitflip-free (✘), not fine-tuned (✘) entry. | diff --git a/docs/images/bitflip/7b-lora-trainloss.png b/docs/images/bitflip/7b-lora-trainloss.png new file mode 100644 index 0000000..601d2cb Binary files /dev/null and b/docs/images/bitflip/7b-lora-trainloss.png differ diff --git a/docs/index.md b/docs/index.md index 33b4fe6..3556a3f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,16 +12,24 @@ - [x] Filter out promising new compute paradigms by running small & medium scale experiments (Roberta on GLUE) - [ ] Scale up the promising new compute paradigms to large-scale language models - [ ] Fine-tuning/pretraining of CLM models (60M - 1.1B) + - [x] Random bitflip - [x] Optical compute - [ ] Spiking neural networks - [ ] In-memory compute - [ ] Parameter-efficient fine-tuning of larger LLMs (e.g., Llama-3.1-8B) + - [x] Random bitflip (promising results) - [x] Optical compute (failed to converge) ## What's New -- 🚧**4th Oct, 2025 Milestone**: Fine-tuning/pretraining of alternative compute paradigms on CLMs. 
+- **4th Feb, 2026 Milestone**: We have successfully fine-tuned Llama-3.1-8B with random bitflip noise injected in forward passes, and observed promising results: LoRA adapters with only 1.2% trainable parameters can effectively mitigate the effect of noise (reducing perplexity from 1008.95 to 11.01, with the original clean perplexity at 7.91). + + | Item | Description | + | ---- | ----------- | + | Llama-3.1-8B with random bitflip noise | [Tutorial](./02-model-behaviour-level-simulation/clm-bitflip-lora-finetune.md) | + +- **4th Oct, 2025 Milestone**: Fine-tuning/pretraining of alternative compute paradigms on CLMs. | Item | Description | | ---- | ----------- | diff --git a/experiments/llm-bitflip/lora_finetune/eval-bitflip-no-finetune.sh b/experiments/llm-bitflip/lora_finetune/eval-bitflip-no-finetune.sh new file mode 100644 index 0000000..cce79fc --- /dev/null +++ b/experiments/llm-bitflip/lora_finetune/eval-bitflip-no-finetune.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Evaluation-only run on train/validation splits with bitflip LoRA transform (no trainable params). 
+# Usage: ./eval-bitflip-no-finetune.sh [num_processes] [model_name_or_path] [per_device_batch_size] [block_size] [eval_max_steps] + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd) +RUN_SCRIPT="${SCRIPT_DIR}/run_clm_no_trainer.py" +TRANSFORM_CFG="${SCRIPT_DIR}/transform_cfg.toml" + +NUM_PROCESSES=${1:-8} +MODEL_NAME_OR_PATH=${2:-"unsloth/Llama-3.1-8B"} +PER_DEVICE_BATCH_SIZE=${3:-1} +BLOCK_SIZE=${4:-2048} +EVAL_MAX_STEPS=${5:-64} + +OUTPUT_DIR="${SCRIPT_DIR}/output/$(basename ${MODEL_NAME_OR_PATH})-bitflip-lora-eval" +WANDB_TAGS="${MODEL_NAME_OR_PATH},bitflip,eval" + +echo "============================================" +echo "Evaluation Only (Bitflip LoRA):" +echo "============================================" +echo "Model: ${MODEL_NAME_OR_PATH}" +echo "Number of Processes: ${NUM_PROCESSES}" +echo "Per Device Batch Size: ${PER_DEVICE_BATCH_SIZE}" +echo "Block Size: ${BLOCK_SIZE}" +if [ "${EVAL_MAX_STEPS}" -gt 0 ]; then + echo "Eval Max Steps per split: ${EVAL_MAX_STEPS}" +else + echo "Eval Max Steps per split: full dataset" +fi +echo "Output Directory: ${OUTPUT_DIR}" +echo "Wandb Tags: ${WANDB_TAGS}" +echo "============================================" + +uv run accelerate launch --num_processes=${NUM_PROCESSES} \ + "${RUN_SCRIPT}" \ + --model_name_or_path ${MODEL_NAME_OR_PATH} \ + --dataset_name Cheng98/fineweb-edu-1.25B \ + --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \ + --per_device_eval_batch_size ${PER_DEVICE_BATCH_SIZE} \ + --num_train_epochs 1 \ + --gradient_accumulation_steps 1 \ + --lr_scheduler_type linear \ + --output_dir ${OUTPUT_DIR} \ + --preprocessing_num_workers 32 \ + --trust_remote_code \ + --with_tracking \ + --report_to wandb \ + --transform_cfg "${TRANSFORM_CFG}" \ + --block_size ${BLOCK_SIZE} \ + --eval_only \ + --eval_max_steps ${EVAL_MAX_STEPS} \ + --wandb_tags ${WANDB_TAGS} diff --git a/experiments/llm-bitflip/lora_finetune/eval-no-biflip-no-finetune.sh b/experiments/llm-bitflip/lora_finetune/eval-no-biflip-no-finetune.sh 
new file mode 100644 index 0000000..a43dea5 --- /dev/null +++ b/experiments/llm-bitflip/lora_finetune/eval-no-biflip-no-finetune.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Evaluation-only run on train/validation splits with baseline LoRA (no bitflip, no trainable params). +# Usage: ./eval-no-biflip-no-finetune.sh [num_processes] [model_name_or_path] [per_device_batch_size] [block_size] [eval_max_steps] + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd) +RUN_SCRIPT="${SCRIPT_DIR}/run_clm_no_trainer.py" +TRANSFORM_CFG="${SCRIPT_DIR}/transform_cfg_baseline.toml" + +NUM_PROCESSES=${1:-8} +MODEL_NAME_OR_PATH=${2:-"unsloth/Llama-3.1-8B"} +PER_DEVICE_BATCH_SIZE=${3:-1} +BLOCK_SIZE=${4:-2048} +EVAL_MAX_STEPS=${5:-64} + +OUTPUT_DIR="${SCRIPT_DIR}/output/$(basename ${MODEL_NAME_OR_PATH})-lora-baseline-eval" +WANDB_TAGS="${MODEL_NAME_OR_PATH},baseline,eval" + +echo "============================================" +echo "Evaluation Only (LoRA Baseline, No Bitflip):" +echo "============================================" +echo "Model: ${MODEL_NAME_OR_PATH}" +echo "Number of Processes: ${NUM_PROCESSES}" +echo "Per Device Batch Size: ${PER_DEVICE_BATCH_SIZE}" +echo "Block Size: ${BLOCK_SIZE}" +if [ "${EVAL_MAX_STEPS}" -gt 0 ]; then + echo "Eval Max Steps per split: ${EVAL_MAX_STEPS}" +else + echo "Eval Max Steps per split: full dataset" +fi +echo "Output Directory: ${OUTPUT_DIR}" +echo "Wandb Tags: ${WANDB_TAGS}" +echo "============================================" + +uv run accelerate launch --num_processes=${NUM_PROCESSES} \ + "${RUN_SCRIPT}" \ + --model_name_or_path ${MODEL_NAME_OR_PATH} \ + --dataset_name Cheng98/fineweb-edu-1.25B \ + --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \ + --per_device_eval_batch_size ${PER_DEVICE_BATCH_SIZE} \ + --num_train_epochs 1 \ + --gradient_accumulation_steps 1 \ + --lr_scheduler_type linear \ + --output_dir ${OUTPUT_DIR} \ + --preprocessing_num_workers 32 \ + --trust_remote_code \ + --with_tracking \ + --report_to wandb \ + 
--transform_cfg "${TRANSFORM_CFG}" \ + --block_size ${BLOCK_SIZE} \ + --eval_only \ + --eval_max_steps ${EVAL_MAX_STEPS} \ + --wandb_tags ${WANDB_TAGS} diff --git a/experiments/llm-bitflip/lora_finetune/fine-tune-bitflip-clm.sh b/experiments/llm-bitflip/lora_finetune/fine-tune-bitflip-clm.sh new file mode 100755 index 0000000..bfe75f5 --- /dev/null +++ b/experiments/llm-bitflip/lora_finetune/fine-tune-bitflip-clm.sh @@ -0,0 +1,103 @@ +#!/bin/bash + +# Parameterized fine-tuning script with proper max_train_steps calculation +# Usage: ./fine-tune-bitflip-clm.sh [num_processes] [model_name_or_path] [per_device_train_batch_size] [learning_rate] [weight_decay] [gradient_accumulation_steps] [block_size] + +# Default parameters +NUM_PROCESSES=${1:-8} +MODEL_NAME_OR_PATH=${2:-"unsloth/Llama-3.1-8B"} +PER_DEVICE_TRAIN_BATCH_SIZE=${3:-1} +LEARNING_RATE=${4:-"1e-5"} +WEIGHT_DECAY=${5:-"0.01"} +GRADIENT_ACCUMULATION_STEPS=${6:-2} +BLOCK_SIZE=${7:-2048} + +# Function to get model parameters count +get_model_params() { + case "$1" in + "AICrossSim/clm-60m") + echo "60000000" + ;; + "AICrossSim/clm-200m") + echo "200000000" + ;; + "AICrossSim/clm-400m") + echo "400000000" + ;; + "AICrossSim/clm-600m") + echo "600000000" + ;; + "AICrossSim/clm-1.1b") + echo "1100000000" + ;; + "unsloth/Llama-3.1-8B") + echo "8000000000" + ;; + *) + echo "Unknown model: $1" >&2 + exit 1 + ;; + esac +} + +# Calculate derived parameters +N_PARAMS=$(get_model_params "$MODEL_NAME_OR_PATH") +N_FINE_TUNE_TOKENS=$((1 * N_PARAMS / 100)) +N_SAMPLES_PER_STEP=$((NUM_PROCESSES * PER_DEVICE_TRAIN_BATCH_SIZE)) +N_TOKENS_PER_STEP=$((N_SAMPLES_PER_STEP * BLOCK_SIZE)) + +# Calculate max_train_steps using ceiling division: (a + b - 1) / b +MAX_TRAIN_STEPS=$(((N_FINE_TUNE_TOKENS + N_TOKENS_PER_STEP - 1) / N_TOKENS_PER_STEP)) + +echo "Calculated max_train_steps: ${MAX_TRAIN_STEPS}" + + +# Generate output directory name +OUTPUT_DIR="./output/$(basename ${MODEL_NAME_OR_PATH})-bitflip-lora" + +# Generate wandb tags 
+WANDB_TAGS="${MODEL_NAME_OR_PATH},lr${LEARNING_RATE},steps${MAX_TRAIN_STEPS}" + +echo "============================================" +echo "Fine-tuning Configuration:" +echo "============================================" +echo "Model: ${MODEL_NAME_OR_PATH}" +echo "Model Parameters: ${N_PARAMS}" +echo "Number of Processes: ${NUM_PROCESSES}" +echo "Per Device Train Batch Size: ${PER_DEVICE_TRAIN_BATCH_SIZE}" +echo "Learning Rate: ${LEARNING_RATE}" +echo "Weight Decay: ${WEIGHT_DECAY}" +echo "Gradient Accumulation Steps: ${GRADIENT_ACCUMULATION_STEPS}" +echo "Block Size: ${BLOCK_SIZE}" +echo "" +echo "Calculated Parameters:" +echo "Fine-tune Tokens: ${N_FINE_TUNE_TOKENS}" +echo "Samples per Step: ${N_SAMPLES_PER_STEP}" +echo "Tokens per Step: ${N_TOKENS_PER_STEP}" +echo "Max Train Steps: ${MAX_TRAIN_STEPS}" +echo "Output Directory: ${OUTPUT_DIR}" +echo "Wandb Tags: ${WANDB_TAGS}" +echo "============================================" + +# Run the training +uv run accelerate launch --num_processes=${NUM_PROCESSES} \ + run_clm_no_trainer.py \ + --model_name_or_path ${MODEL_NAME_OR_PATH} \ + --dataset_name Cheng98/fineweb-edu-1.25B \ + --per_device_train_batch_size ${PER_DEVICE_TRAIN_BATCH_SIZE} \ + --per_device_eval_batch_size ${PER_DEVICE_TRAIN_BATCH_SIZE} \ + --learning_rate ${LEARNING_RATE} \ + --weight_decay ${WEIGHT_DECAY} \ + --num_train_epochs 1 \ + --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEPS} \ + --lr_scheduler_type linear \ + --output_dir ${OUTPUT_DIR} \ + --preprocessing_num_workers 32 \ + --trust_remote_code \ + --with_tracking \ + --report_to wandb \ + --transform_cfg ./transform_cfg.toml \ + --block_size ${BLOCK_SIZE} \ + --log_train_loss_steps 50 \ + --max_train_steps ${MAX_TRAIN_STEPS} \ + --wandb_tags ${WANDB_TAGS} diff --git a/experiments/llm-bitflip/lora_finetune/plot_train_loss.py b/experiments/llm-bitflip/lora_finetune/plot_train_loss.py new file mode 100644 index 0000000..2378baf --- /dev/null +++ 
b/experiments/llm-bitflip/lora_finetune/plot_train_loss.py @@ -0,0 +1,55 @@ +import argparse +from pathlib import Path +from typing import Optional + +import matplotlib.pyplot as plt +import pandas as pd + + +def plot_train_loss(csv_path: Path, output_path: Optional[Path] = None) -> None: + """Plot train loss vs step and save the figure.""" + df = pd.read_csv(csv_path) + + step_col = "Step" + loss_col = "Llama-3.1-8B-bitflip-lora-r32 - train_loss" + baseline_loss = 2.06842 + + if step_col not in df.columns: + raise ValueError(f"Missing column: {step_col}") + if loss_col not in df.columns: + raise ValueError(f"Missing column: {loss_col}") + + fig, ax = plt.subplots(figsize=(8, 4.5)) + ax.plot( + df[step_col], df[loss_col], label="bitflip r32", color="tab:blue", linewidth=1.8 + ) + ax.axhline( + baseline_loss, color="tab:red", linestyle="--", linewidth=1.2, label="original" + ) + + ax.set_xlabel("Train step") + ax.set_ylabel("Train loss") + ax.set_title("Llama-3.1-8B bitflip lora Fine-tuning") + ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.7) + ax.legend() + fig.tight_layout() + + if output_path is None: + output_path = csv_path.with_suffix(".png") + output_path.parent.mkdir(parents=True, exist_ok=True) + + fig.savefig(output_path, dpi=200) + print(f"Saved plot to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Plot train loss vs step from a W&B CSV export." 
+ ) + parser.add_argument("csv", type=Path, help="Path to the W&B CSV export") + parser.add_argument( + "--output", "-o", type=Path, default=None, help="Path to save the plot (PNG)" + ) + args = parser.parse_args() + + plot_train_loss(args.csv, args.output) diff --git a/experiments/llm-bitflip/lora_finetune/run_clm_no_trainer.py b/experiments/llm-bitflip/lora_finetune/run_clm_no_trainer.py new file mode 100644 index 0000000..d3d8312 --- /dev/null +++ b/experiments/llm-bitflip/lora_finetune/run_clm_no_trainer.py @@ -0,0 +1,1087 @@ +#!/usr/bin/env python +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "transformers @ git+https://github.com/huggingface/transformers.git", +# "albumentations >= 1.4.16", +# "accelerate >= 0.12.0", +# "torch >= 1.3", +# "datasets >= 2.14.0", +# "sentencepiece != 0.1.92", +# "protobuf", +# "evaluate", +# "scikit-learn", +# ] +# /// + +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" + +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
+ +import argparse +import json +import logging +import math +import os +import random +import sys +from itertools import chain +from pathlib import Path + +sys.path.append(Path(__file__).resolve().parents[3].joinpath("src").as_posix()) + +import datasets +import tomllib +import torch +import transformers +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import HfApi +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + LlamaForCausalLM, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + +from aixsim_models.bitflip.fine_tune.bitflip_llama import transform_llama + +FC_CFG = dict( + x_p_exp=None, + x_p_frac=None, + x_zero_out_t=None, + w_p_exp=None, + w_p_frac=None, + w_zero_out_t=None, + x_seed_exp=0, + x_seed_frac=0, + w_seed_exp=0, + w_seed_frac=0, +) + +LORA_CFG = dict(r=32, lora_alpha=32) + +logger = get_logger(__name__) + +require_version( + "datasets>=2.14.0", + "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt", +) + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Finetune a transformers model on a causal language modeling task" + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", 
+ type=str, + default=None, + help="A csv, txt or a json file containing the training data.", + ) + parser.add_argument( + "--validation_file", + type=str, + default=None, + help="A csv, txt or a json file containing the validation data.", + ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the ๐Ÿค— Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--weight_decay", type=float, default=0.0, help="Weight decay to use." + ) + parser.add_argument( + "--num_train_epochs", + type=int, + default=3, + help="Total number of training epochs to perform.", + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=[ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + ], + ) + parser.add_argument( + "--num_warmup_steps", + type=int, + default=0, + help="Number of steps for the warmup in the lr scheduler.", + ) + parser.add_argument( + "--output_dir", type=str, default=None, help="Where to store the final model." + ) + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." 
+ ), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", + action="store_true", + help="Overwrite the cached training and evaluation sets", + ) + parser.add_argument( + "--no_keep_linebreaks", + action="store_true", + help="Do not keep line breaks when using TXT files.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--hub_token", type=str, help="The token to use to push to the Model Hub." + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help=( + "Whether to trust the execution of code from datasets/models defined on the Hub." + " This option should only be set to `True` for repositories you trust and in which you have read the" + " code, as it will execute code present on the Hub on your local machine." + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. ' + "Only applicable when `--with_tracking` is passed." 
+ ), + ) + parser.add_argument( + "--transform_cfg", + type=Path, + default=None, + help="Path to the transform configuration file.", + ) + parser.add_argument( + "--log_train_loss_steps", + type=int, + default=None, + help="Log training loss every n steps. If not provided, training loss will be logged at the end of each epoch.", + ) + parser.add_argument( + "--wandb_tags", + type=lambda s: s.split(","), + default=None, + help="A comma-separated list of tags to apply to the W&B run.", + ) + parser.add_argument( + "--eval_only", + action="store_true", + help="Run evaluation on train/validation splits without any training or trainable parameters.", + ) + parser.add_argument( + "--eval_max_steps", + type=int, + default=64, + help="Maximum number of evaluation batches per split. Use -1 to evaluate the full split.", + ) + args = parser.parse_args() + + if args.eval_max_steps is not None and args.eval_max_steps < 0: + args.eval_max_steps = None + + # Sanity checks + if ( + args.dataset_name is None + and args.train_file is None + and args.validation_file is None + ): + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`train_file` should be a csv, json or txt file.") + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + if extension not in ["csv", "json", "txt"]: + raise ValueError("`validation_file` should be a csv, json or txt file.") + + if args.push_to_hub: + if args.output_dir is None: + raise ValueError( + "Need an `output_dir` to create a repo when `--push_to_hub` is passed." 
+ ) + + return args + + +def main(): + args = parse_args() + + transform_cfg = None + use_lora = False + fc_cfg = None + if args.transform_cfg is not None: + with open(args.transform_cfg, "rb") as f: + transform_cfg = tomllib.load(f) + use_lora = transform_cfg["use_lora"] + fc_cfg = FC_CFG | transform_cfg["fc"] + lora_cfg = LORA_CFG | transform_cfg["lora"] + if use_lora: + fc_cfg = fc_cfg | lora_cfg + + if use_lora: + + def set_trainable(model: torch.nn.Module): + trainable = [] + n_params = 0 + total_params = 0 + for n, p in model.named_parameters(): + if "lora_" in n.lower(): + p.requires_grad = True + trainable.append(n) + n_params += p.numel() + else: + p.requires_grad = False + total_params += p.numel() + logger.info( + f"Number of trainable parameters: {n_params:,} ({100 * n_params / total_params:.2f}%)\nTrainable parameters: {trainable}" + ) + + else: + + def set_trainable(model: torch.nn.Module): + trainable = [] + total_params = 0 + for n, p in model.named_parameters(): + p.requires_grad = True + trainable.append(n) + total_params += p.numel() + logger.info( + f"Number of trainable parameters: {total_params:,}\nTrainable parameters: {trainable}" + ) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["project_dir"] = args.output_dir + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + **accelerator_log_kwargs, + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + # Retrieve or infer repo_name + repo_name = args.hub_model_id + if repo_name is None: + repo_name = Path(args.output_dir).absolute().name + # Create repo and retrieve repo_id + api = HfApi() + repo_id = api.create_repo( + repo_name, exist_ok=True, token=args.hub_token + ).repo_id + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub.
+ raw_datasets = load_dataset( + args.dataset_name, + args.dataset_config_name, + trust_remote_code=args.trust_remote_code, + ) + if "validation" not in raw_datasets: + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + trust_remote_code=args.trust_remote_code, + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + trust_remote_code=args.trust_remote_code, + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + extension = args.train_file.split(".")[-1] + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.validation_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets: + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
+ if args.config_name: + config = AutoConfig.from_pretrained( + args.config_name, + trust_remote_code=args.trust_remote_code, + ) + elif args.model_name_or_path: + config = AutoConfig.from_pretrained( + args.model_name_or_path, + trust_remote_code=args.trust_remote_code, + ) + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer_name, + use_fast=not args.use_slow_tokenizer, + trust_remote_code=args.trust_remote_code, + ) + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, + use_fast=not args.use_slow_tokenizer, + trust_remote_code=args.trust_remote_code, + ) + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script. " + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config( + config, + trust_remote_code=args.trust_remote_code, + ) + + assert isinstance(model, LlamaForCausalLM), "model must be a LlamaForCausalLM model" + if transform_cfg is not None: + print("Transforming model...") + replaced_layers = transform_llama(model, fc_cfg, use_lora=use_lora) + print(f"Replaced {len(replaced_layers)} layers") + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. 
+ embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with accelerator.main_process_first(): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > config.max_position_embeddings: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + f"Using block_size={min(1024, config.max_position_embeddings)} instead. You can change that default value by passing --block_size xxx." + ) + block_size = min(1024, config.max_position_embeddings) + else: + if args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model " + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 
+ # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder + # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower + # to preprocess. + # + # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/process#map + + with accelerator.main_process_first(): + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=args.preprocessing_num_workers, + load_from_cache_file=not args.overwrite_cache, + desc=f"Grouping texts in chunks of {block_size}", + ) + + train_dataset = lm_datasets["train"] + eval_dataset = lm_datasets["validation"] + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + train_dataloader = DataLoader( + train_dataset, + shuffle=not args.eval_only, + collate_fn=default_data_collator, + batch_size=args.per_device_train_batch_size, + ) + eval_dataloader = DataLoader( + eval_dataset, + collate_fn=default_data_collator, + batch_size=args.per_device_eval_batch_size, + ) + + if args.eval_only: + # Ensure nothing is marked trainable and skip optimizer/scheduler setup. 
+ for p in model.parameters(): + p.requires_grad = False + + model, train_dataloader, eval_dataloader = accelerator.prepare( + model, train_dataloader, eval_dataloader + ) + + if args.with_tracking: + experiment_config = vars(args) + experiment_config["lr_scheduler_type"] = experiment_config[ + "lr_scheduler_type" + ].value + experiment_config["bitflip_config"] = { + "use_lora": use_lora, + "fc_cfg": fc_cfg, + } + accelerator.init_trackers( + "Bitflip-CLM-Eval", + experiment_config, + init_kwargs={ + "wandb": { + "name": args.output_dir.split("/")[-1], + "tags": args.wandb_tags if args.wandb_tags is not None else [], + }, + }, + ) + + def evaluate_split(split_name: str, dataloader: DataLoader): + model.eval() + losses = [] + for step, batch in enumerate(dataloader): + with torch.no_grad(): + outputs = model(**batch) + + loss = outputs.loss + losses.append( + accelerator.gather_for_metrics( + loss.repeat(args.per_device_eval_batch_size) + ) + ) + + if args.eval_max_steps is not None and args.eval_max_steps > 0: + if (step + 1) >= args.eval_max_steps: + break + + if len(losses) == 0: + logger.warning( + f"No batches processed for {split_name} split during evaluation." + ) + return None + + losses = torch.cat(losses) + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + eval_loss = torch.mean(losses) + perplexity = float("inf") + + logger.info( + f"{split_name} perplexity: {perplexity} loss: {eval_loss} (steps={(len(losses) // args.per_device_eval_batch_size)})" + ) + + if args.with_tracking: + accelerator.log( + { + f"{split_name}_perplexity": perplexity, + f"{split_name}_loss": eval_loss, + }, + step=0, + ) + + return eval_loss, perplexity + + evaluate_split("train", train_dataloader) + evaluate_split("validation", eval_dataloader) + accelerator.wait_for_everyone() + accelerator.end_training() + return + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. 
+ no_decay = ["bias", "layer_norm.weight"] + set_trainable(model) + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if not any(nd in n for nd in no_decay) and p.requires_grad + ], + "weight_decay": args.weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if any(nd in n for nd in no_decay) and p.requires_grad + ], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps * accelerator.num_processes, + num_training_steps=( + args.max_train_steps + if overrode_max_train_steps + else args.max_train_steps * accelerator.num_processes + ), + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = ( + accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config[ + "lr_scheduler_type" + ].value + experiment_config["bitflip_config"] = { + "use_lora": use_lora, + "fc_cfg": fc_cfg, + } + accelerator.init_trackers( + "Bitflip-CLM-Fine-tune", + experiment_config, + init_kwargs={ + "wandb": { + "name": args.output_dir.split("/")[-1], + "tags": args.wandb_tags if args.wandb_tags is not None else [], + }, + }, + ) + + # Train! + total_batch_size = ( + args.per_device_train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info( + f" Instantaneous batch size per device = {args.per_device_train_batch_size}" + ) + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm( + range(args.max_train_steps), disable=not accelerator.is_local_main_process + ) + completed_steps = 0 + starting_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + checkpoint_path = args.resume_from_checkpoint + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[ + -1 + ] # Sorts folders by date modified, most recent checkpoint is the last + checkpoint_path = path + path = os.path.basename(checkpoint_path) + + accelerator.print(f"Resumed from checkpoint: {checkpoint_path}") + accelerator.load_state(checkpoint_path) + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + completed_steps = starting_epoch * num_update_steps_per_epoch + else: + # need to multiply `gradient_accumulation_steps` to reflect real steps + resume_step = ( + int(training_difference.replace("step_", "")) + * args.gradient_accumulation_steps + ) + starting_epoch = resume_step // len(train_dataloader) + completed_steps = resume_step // args.gradient_accumulation_steps + resume_step -= starting_epoch * len(train_dataloader) + + # update the progress_bar if load from checkpoint + progress_bar.update(completed_steps) + + # Initialize step-based loss tracking + if args.with_tracking and args.log_train_loss_steps is not None: + step_losses = [] + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if args.with_tracking and args.log_train_loss_steps is None: + total_loss = 0 + if ( + args.resume_from_checkpoint + and epoch == starting_epoch + and resume_step is not None + ): + # We skip the 
first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches( + train_dataloader, resume_step + ) + else: + active_dataloader = train_dataloader + for step, batch in enumerate(active_dataloader): + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss + if args.with_tracking: + if args.log_train_loss_steps is not None: + # Step-based logging + step_losses.append(loss.detach().float()) + else: + # Epoch-based logging (original behavior) + total_loss += loss.detach().float() + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + completed_steps += 1 + + # Log training loss every n steps if specified + if ( + args.with_tracking + and args.log_train_loss_steps is not None + and completed_steps % args.log_train_loss_steps == 0 + and len(step_losses) > 0 + ): + avg_train_loss = torch.mean(torch.stack(step_losses)).item() + accelerator.log( + { + "train_loss": avg_train_loss, + "step": completed_steps, + }, + step=completed_steps, + ) + logger.info( + f"Step {completed_steps}: train_loss: {avg_train_loss:.4f}" + ) + step_losses = [] # Reset for next interval + + if isinstance(checkpointing_steps, int): + if ( + completed_steps % checkpointing_steps == 0 + and accelerator.sync_gradients + ): + output_dir = f"step_{completed_steps}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + if completed_steps >= args.max_train_steps: + break + + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = model(**batch) + + loss = outputs.loss + losses.append( + accelerator.gather_for_metrics( + loss.repeat(args.per_device_eval_batch_size) + ) + ) + if 
args.eval_max_steps is not None and args.eval_max_steps > 0: + if (step + 1) >= args.eval_max_steps: + break + + losses = torch.cat(losses) + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}") + + if args.with_tracking: + log_dict = { + "perplexity": perplexity, + "eval_loss": eval_loss, + "epoch": epoch, + "step": completed_steps, + } + # Only log epoch-based train_loss if not using step-based logging + if args.log_train_loss_steps is None: + log_dict["train_loss"] = total_loss.item() / len(train_dataloader) + + accelerator.log(log_dict, step=completed_steps) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + api.upload_folder( + commit_message=f"Training in progress epoch {epoch}", + folder_path=args.output_dir, + repo_id=repo_id, + repo_type="model", + token=args.hub_token, + ) + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, + is_main_process=accelerator.is_main_process, + save_function=accelerator.save, + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + api.upload_folder( + commit_message="End of training", + folder_path=args.output_dir, + repo_id=repo_id, + repo_type="model", + token=args.hub_token, + ) + with 
open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump({"perplexity": perplexity}, f) + + # Log any remaining step losses that haven't been logged yet + if ( + args.with_tracking + and args.log_train_loss_steps is not None + and len(step_losses) > 0 + ): + avg_train_loss = torch.mean(torch.stack(step_losses)).item() + accelerator.log( + { + "train_loss": avg_train_loss, + "step": completed_steps, + }, + step=completed_steps, + ) + logger.info( + f"Final Step {completed_steps}: train_loss: {avg_train_loss:.4f} (remaining {len(step_losses)} steps)" + ) + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/experiments/llm-bitflip/lora_finetune/transform_cfg.toml b/experiments/llm-bitflip/lora_finetune/transform_cfg.toml new file mode 100644 index 0000000..d19d1df --- /dev/null +++ b/experiments/llm-bitflip/lora_finetune/transform_cfg.toml @@ -0,0 +1,13 @@ +use_lora = true + +[fc] + w_p_exp = 1.52587890625e-05 + w_p_frac = 1.52587890625e-05 + w_zero_out_t = 1.25 + x_p_exp = 1.52587890625e-05 + x_p_frac = 1.52587890625e-05 + x_zero_out_t = 30.0 + +[lora] + r = 32 + lora_alpha = 32 diff --git a/experiments/llm-bitflip/lora_finetune/transform_cfg_baseline.toml b/experiments/llm-bitflip/lora_finetune/transform_cfg_baseline.toml new file mode 100644 index 0000000..a4b26cf --- /dev/null +++ b/experiments/llm-bitflip/lora_finetune/transform_cfg_baseline.toml @@ -0,0 +1,8 @@ +use_lora = true + +[fc] + # No bitflip โ€” all probabilities remain at default (None) + +[lora] + r = 32 + lora_alpha = 32 diff --git a/experiments/llm-bitflip/pretrain/run.py b/experiments/llm-bitflip/pretrain/run.py index d72e4a4..7d2aff0 100644 --- a/experiments/llm-bitflip/pretrain/run.py +++ b/experiments/llm-bitflip/pretrain/run.py @@ -8,16 +8,15 @@ import math from pathlib import Path -import torch from aixsim_models.llm.profiler import profile_num_params from aixsim_models.llm import register_model_configs, 
register_pretrain_dataset from aixsim_models.utils.logging import set_logging_verbosity from aixsim_models.llm.evaluator import pt_evaluate_ppl, hf_check_ppl, hf_lm_eval from aixsim_models.llm.utils import convert_torch_to_hf -from aixsim_models.bitflip.pretrainer import pretrain -from aixsim_models.bitflip.arg_manager import ArgRandomBitFlipTransform -from aixsim_models.bitflip.arg_manager import ( +from aixsim_models.bitflip.pretrain.pretrainer import pretrain +from aixsim_models.bitflip.pretrain.arg_manager import ArgRandomBitFlipTransform +from aixsim_models.bitflip.pretrain.arg_manager import ( ArgJob, ArgProfiling, ArgMetrics, @@ -32,7 +31,7 @@ ArgMemoryEstimation, PreTrainArgs, ) -from aixsim_models.bitflip.profiler import profile_stats_hf +from aixsim_models.bitflip.pretrain.profiler import profile_stats_hf register_model_configs() register_pretrain_dataset() @@ -97,7 +96,9 @@ def generate_pretrain_cfg( silent=True, ) num_tokens = token_num_scale * num_params - effective_batch_size = batch_size * data_parallel_replicate_degree * abs(data_parallel_shard_degree) + effective_batch_size = ( + batch_size * data_parallel_replicate_degree * abs(data_parallel_shard_degree) + ) num_steps = math.ceil(num_tokens / (effective_batch_size * seq_len)) print( @@ -106,7 +107,9 @@ def generate_pretrain_cfg( print(f"Effective batch size: {effective_batch_size}") print(f"Estimated number of steps: {num_steps}") - assert transform_config.exists(), f"Transform config file {transform_config} does not exist" + assert ( + transform_config.exists() + ), f"Transform config file {transform_config} does not exist" with open(transform_config, "r") as f: transform_config = yaml.safe_load(f) @@ -117,7 +120,9 @@ def generate_pretrain_cfg( ), profiling=ArgProfiling(), metrics=ArgMetrics(enable_tensorboard=False, enable_wandb=True), - model=ArgModel(name=model_arch, flavor=model_flavor, tokenizer_path=tokenizer_path), + model=ArgModel( + name=model_arch, flavor=model_flavor, 
tokenizer_path=tokenizer_path + ), optimizer=ArgOptimizer(lr=learning_rate), training=ArgTraining( dataset="fineweb-edu", @@ -169,9 +174,17 @@ def pt_eval_ppl( seq_len: int = 2048, ): from pprint import pformat - from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config + from torchtitan.models import ( + model_name_to_cls, + model_name_to_tokenizer, + models_config, + ) from aixsim_models.llm.tokenizer import build_tokenizer - from aixsim_models.bitflip.transform import transform_model, TransformConfigManager, make_transform_histogram + from aixsim_models.bitflip.transform import ( + transform_model, + TransformConfigManager, + make_transform_histogram, + ) transform_config_manager = TransformConfigManager( layer_name_to_config=transform_config.layer_name_to_config, diff --git a/experiments/llm-bitflip/transform/minimal.py b/experiments/llm-bitflip/transform/minimal.py index 0f6aa17..5a7bee2 100644 --- a/experiments/llm-bitflip/transform/minimal.py +++ b/experiments/llm-bitflip/transform/minimal.py @@ -13,7 +13,10 @@ from jsonargparse import CLI from aixsim_models.llm.evaluator import hf_lm_eval, hf_generate -from aixsim_models.bitflip.transform import transform_model, TransformConfigManager +from aixsim_models.bitflip.pretrain.transform import ( + transform_model, + TransformConfigManager, +) DEFAULT_DTYPE = "float16" DEFAULT_TASKS = ["wikitext"] @@ -31,7 +34,9 @@ def eval_ori( ): """Evaluate a pretrained model as baseline.""" device = torch.device("cuda") - model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=getattr(torch, dtype)).eval() + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=getattr(torch, dtype) + ).eval() model.to(device) tokenizer = AutoTokenizer.from_pretrained(model_name) diff --git a/experiments/llm-optical-transformer/continual_finetuning/run_clm_no_trainer.py b/experiments/llm-optical-transformer/continual_finetuning/run_clm_no_trainer.py index 488731c..726419a 100644 --- 
a/experiments/llm-optical-transformer/continual_finetuning/run_clm_no_trainer.py +++ b/experiments/llm-optical-transformer/continual_finetuning/run_clm_no_trainer.py @@ -34,6 +34,7 @@ Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=text-generation """ + # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. import argparse @@ -42,9 +43,12 @@ import math import os import random +import sys from itertools import chain from pathlib import Path +sys.path.append(Path(__file__).resolve().parents[3].joinpath("src").as_posix()) + import datasets import tomllib import torch @@ -395,6 +399,7 @@ def set_trainable(model: torch.nn.Module): logger.info( f"Number of trainable parameters: {n_params:,} ({100 * n_params / total_params:.2f}%)\nTrainable parameters: {trainable}" ) + else: def set_trainable(model: torch.nn.Module): @@ -709,9 +714,11 @@ def group_texts(examples): name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps - if overrode_max_train_steps - else args.max_train_steps * accelerator.num_processes, + num_training_steps=( + args.max_train_steps + if overrode_max_train_steps + else args.max_train_steps * accelerator.num_processes + ), ) # Prepare everything with our `accelerator`. 
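Across these fine-tuning scripts, `set_trainable` freezes every parameter except the LoRA adapters and reports the trainable fraction. A minimal, self-contained sketch of that pattern, where `ToyLoraLinear` is an illustrative stand-in (not the repo's `BitFlipLinearLora`):

```python
import torch
import torch.nn as nn


class ToyLoraLinear(nn.Module):
    """Illustrative stand-in for a LoRA-augmented linear layer."""

    def __init__(self, in_features: int, out_features: int, r: int = 4):
        super().__init__()
        # Frozen pretrained weight plus two small trainable low-rank factors.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))


def set_trainable(model: nn.Module) -> tuple[int, int]:
    """Freeze every parameter whose name does not contain 'lora_'."""
    n_trainable = n_total = 0
    for name, param in model.named_parameters():
        param.requires_grad = "lora_" in name.lower()
        n_total += param.numel()
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable, n_total


model = ToyLoraLinear(16, 16, r=4)
n_trainable, n_total = set_trainable(model)
print(
    f"{n_trainable:,} / {n_total:,} parameters trainable "
    f"({100 * n_trainable / n_total:.2f}%)"
)
```

Note that this sketch counts all parameters, trainable and frozen alike, in the denominator; only the parameters left with `requires_grad=True` are then handed to the optimizer's parameter groups.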
diff --git a/experiments/llm-optical-transformer/continual_pretraining/run_clm.py b/experiments/llm-optical-transformer/continual_pretraining/run_clm.py index a38b425..6b03ec8 100644 --- a/experiments/llm-optical-transformer/continual_pretraining/run_clm.py +++ b/experiments/llm-optical-transformer/continual_pretraining/run_clm.py @@ -283,14 +283,18 @@ def __post_init__(self): else: if self.train_file is not None: extension = self.train_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], ( - "`train_file` should be a csv, a json or a txt file." - ) + assert extension in [ + "csv", + "json", + "txt", + ], "`train_file` should be a csv, a json or a txt file." if self.validation_file is not None: extension = self.validation_file.split(".")[-1] - assert extension in ["csv", "json", "txt"], ( - "`validation_file` should be a csv, a json or a txt file." - ) + assert extension in [ + "csv", + "json", + "txt", + ], "`validation_file` should be a csv, a json or a txt file." def main(): @@ -680,9 +684,11 @@ def compute_metrics(eval_preds): processing_class=tokenizer, # Data collator will default to DataCollatorWithPadding, so we change it. 
data_collator=default_data_collator, - compute_metrics=compute_metrics - if training_args.do_eval and not is_torch_xla_available() - else None, + compute_metrics=( + compute_metrics + if training_args.do_eval and not is_torch_xla_available() + else None + ), preprocess_logits_for_metrics=( preprocess_logits_for_metrics if training_args.do_eval and not is_torch_xla_available() diff --git a/experiments/llm-optical-transformer/lora_finetuning/run_clm_no_trainer.py b/experiments/llm-optical-transformer/lora_finetuning/run_clm_no_trainer.py index 1ebfede..4a5bb8d 100644 --- a/experiments/llm-optical-transformer/lora_finetuning/run_clm_no_trainer.py +++ b/experiments/llm-optical-transformer/lora_finetuning/run_clm_no_trainer.py @@ -34,6 +34,7 @@ Here is the full list of checkpoints on the hub that can be fine-tuned by this script: https://huggingface.co/models?filter=text-generation """ + # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
import argparse @@ -42,9 +43,12 @@ import math import os import random +import sys from itertools import chain from pathlib import Path +sys.path.append(Path(__file__).resolve().parents[3].joinpath("src").as_posix()) + import datasets import tomllib import torch @@ -395,6 +399,7 @@ def set_trainable(model: torch.nn.Module): logger.info( f"Number of trainable parameters: {n_params:,} ({100 * n_params / total_params:.2f}%)\nTrainable parameters: {trainable}" ) + else: def set_trainable(model: torch.nn.Module): @@ -709,9 +714,11 @@ def group_texts(examples): name=args.lr_scheduler_type, optimizer=optimizer, num_warmup_steps=args.num_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps - if overrode_max_train_steps - else args.max_train_steps * accelerator.num_processes, + num_training_steps=( + args.max_train_steps + if overrode_max_train_steps + else args.max_train_steps * accelerator.num_processes + ), ) # Prepare everything with our `accelerator`. diff --git a/experiments/llm-optical-transformer/pretrain/run.py b/experiments/llm-optical-transformer/pretrain/run.py index 9c66783..9a3eca5 100644 --- a/experiments/llm-optical-transformer/pretrain/run.py +++ b/experiments/llm-optical-transformer/pretrain/run.py @@ -15,8 +15,10 @@ from aixsim_models.llm.evaluator import pt_evaluate_ppl, hf_check_ppl, hf_lm_eval from aixsim_models.llm.utils import convert_torch_to_hf, convert_hf_to_torch -from aixsim_models.optical_compute.optical_transformer.pretrainer import pretrain -from aixsim_models.optical_compute.optical_transformer.arg_manager import ( +from aixsim_models.optical_compute.optical_transformer.pretrain.pretrainer import ( + pretrain, +) +from aixsim_models.optical_compute.optical_transformer.pretrain.arg_manager import ( ArgJob, ArgProfiling, ArgMetrics, @@ -64,7 +66,9 @@ def generate_pretrain_cfg( silent=True, ) num_tokens = token_num_scale * num_params - effective_batch_size = batch_size * data_parallel_replicate_degree * 
abs(data_parallel_shard_degree) + effective_batch_size = ( + batch_size * data_parallel_replicate_degree * abs(data_parallel_shard_degree) + ) num_steps = math.ceil(num_tokens / (effective_batch_size * seq_len)) print( @@ -73,7 +77,9 @@ def generate_pretrain_cfg( print(f"Effective batch size: {effective_batch_size}") print(f"Estimated number of steps: {num_steps}") - assert transform_config.exists(), f"Transform config file {transform_config} does not exist" + assert ( + transform_config.exists() + ), f"Transform config file {transform_config} does not exist" with open(transform_config, "r") as f: transform_config = yaml.safe_load(f) @@ -84,7 +90,9 @@ def generate_pretrain_cfg( ), profiling=ArgProfiling(), metrics=ArgMetrics(enable_tensorboard=False, enable_wandb=True), - model=ArgModel(name=model_arch, flavor=model_flavor, tokenizer_path=tokenizer_path), + model=ArgModel( + name=model_arch, flavor=model_flavor, tokenizer_path=tokenizer_path + ), optimizer=ArgOptimizer(lr=learning_rate), training=ArgTraining( dataset="fineweb-edu", diff --git a/experiments/roberta-optical-transformer/run_glue.py b/experiments/roberta-optical-transformer/run_glue.py index f679da2..3f1edf5 100644 --- a/experiments/roberta-optical-transformer/run_glue.py +++ b/experiments/roberta-optical-transformer/run_glue.py @@ -21,8 +21,11 @@ import random import sys from dataclasses import dataclass, field +from pathlib import Path from typing import Optional +sys.path.append(Path(__file__).resolve().parents[2].joinpath("src").as_posix()) + import datasets import evaluate import numpy as np diff --git a/mkdocs.yml b/mkdocs.yml index f3abbab..67e49e8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -66,6 +66,7 @@ nav: - Model Behaviour Level Simulation: - Random BitFlip: - LLM: "02-model-behaviour-level-simulation/clm-bitflip.md" + - LoRA Fine-Tuning: "02-model-behaviour-level-simulation/clm-bitflip-lora-finetune.md" - Optical Neural Networks: - RoBERTa: 
"02-model-behaviour-level-simulation/roberta-onn.md" - CLM: "02-model-behaviour-level-simulation/clm-onn.md" diff --git a/requirements.txt b/requirements.txt index 7a2ece5..5e2ebc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ colorlog torchdata tensorboard lm-eval==0.4.7 -mase-tools @ git+https://github.com/DeepWok/mase@cz/mase-triton-bitflip \ No newline at end of file +mase-triton \ No newline at end of file diff --git a/src/aixsim_models/bitflip/fine_tune/__init__.py b/src/aixsim_models/bitflip/fine_tune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aixsim_models/bitflip/fine_tune/bitflip_llama.py b/src/aixsim_models/bitflip/fine_tune/bitflip_llama.py new file mode 100644 index 0000000..0ab2d80 --- /dev/null +++ b/src/aixsim_models/bitflip/fine_tune/bitflip_llama.py @@ -0,0 +1,45 @@ +import torch +from mase_triton.random_bitflip.layers import RandomBitFlipLinear +from torch import nn +from transformers.models.llama.modeling_llama import LlamaForCausalLM + +from ...utils.torch_module import set_layer_by_name +from .bitflip_lora import BitFlipLinearLora + + +def transform_llama( + model: LlamaForCausalLM, + fc_config: dict, + use_lora: bool, +) -> list[str]: + """Replace all Linear layers (except lm_head) with bitflip-aware layers. + + Args: + model: A LlamaForCausalLM model. + fc_config: Config dict passed to BitFlipLinearLora.from_linear or + RandomBitFlipLinear.from_linear. When use_lora is True, this should + include both bitflip params and lora params (r, lora_alpha). + use_lora: If True, use BitFlipLinearLora; otherwise use RandomBitFlipLinear. + + Returns: + List of replaced layer names. 
+ """ + assert isinstance(model, LlamaForCausalLM) + replaced_layers = [] + + for name, layer in model.named_modules(): + if not isinstance(layer, nn.Linear): + continue + + if "lm_head" in name: + continue + + if use_lora: + new_layer = BitFlipLinearLora.from_linear(layer, **fc_config) + else: + new_layer = RandomBitFlipLinear.from_linear(layer, **fc_config) + + set_layer_by_name(model, name, new_layer) + replaced_layers.append(name) + + return replaced_layers diff --git a/src/aixsim_models/bitflip/fine_tune/bitflip_lora.py b/src/aixsim_models/bitflip/fine_tune/bitflip_lora.py new file mode 100644 index 0000000..ab1b704 --- /dev/null +++ b/src/aixsim_models/bitflip/fine_tune/bitflip_lora.py @@ -0,0 +1,160 @@ +import math + +import torch +from mase_triton.random_bitflip.core import random_bitflip_fn +from mase_triton.random_bitflip.layers import RandomBitFlipLinear +from torch import Tensor, nn + + +class BitFlipLinearLora(RandomBitFlipLinear): + """RandomBitFlipLinear with LoRA adaptation. + + Forward: Y = bitflip(X) @ bitflip(W + B @ A * scaling)^T + bias + Only lora_A and lora_B are trainable during fine-tuning. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + device, + dtype, + x_p_exp: float | None, + x_p_frac: float | None, + x_zero_out_t: float | None, + w_p_exp: float | None, + w_p_frac: float | None, + w_zero_out_t: float | None, + x_seed_exp: int = 0, + x_seed_frac: int = 0, + w_seed_exp: int = 0, + w_seed_frac: int = 0, + r: int = 32, + lora_alpha: int = 32, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device, + dtype, + x_p_exp=x_p_exp, + x_p_frac=x_p_frac, + x_zero_out_t=x_zero_out_t, + w_p_exp=w_p_exp, + w_p_frac=w_p_frac, + w_zero_out_t=w_zero_out_t, + x_seed_exp=x_seed_exp, + x_seed_frac=x_seed_frac, + w_seed_exp=w_seed_exp, + w_seed_frac=w_seed_frac, + ) + self.r = r + self.lora_alpha = lora_alpha + self.scaling = self.lora_alpha / self.r if r > 0 else 1 + + if r > 0: + self.lora_A = nn.Parameter( + torch.zeros((r, in_features), device=device, dtype=dtype) + ) + self.lora_B = nn.Parameter( + torch.zeros((out_features, r), device=device, dtype=dtype) + ) + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + else: + self.register_parameter("lora_A", None) + self.register_parameter("lora_B", None) + self.r: int + self.lora_A: nn.Parameter | None + self.lora_B: nn.Parameter | None + self.scaling: float + + def forward(self, x: Tensor) -> Tensor: + # 1. Apply input bitflip (if configured) + if not (self.x_p_exp is None and self.x_p_frac is None and self.x_zero_out_t is None): + x, x_seed_exp, x_seed_frac = random_bitflip_fn( + x, + exp_halves=self.x_nearest_exp_halves, + frac_halves=self.x_nearest_frac_halves, + seed_exp=self.x_seed_exp, + seed_frac=self.x_seed_frac, + zero_out_threshold=self.x_zero_out_t, + ) + self.x_seed_exp = x_seed_exp + self.x_seed_frac = x_seed_frac + + # 2. Compute adapted weight: W + B @ A * scaling + w = self.weight + if self.r > 0: + w = w + (self.lora_B @ self.lora_A) * self.scaling + + # 3. 
Apply weight bitflip (skipped when no weight flip probabilities are configured)
+        if not (self.w_p_exp is None and self.w_p_frac is None and self.w_zero_out_t is None):
+            w, w_seed_exp, w_seed_frac = random_bitflip_fn(
+                w,
+                exp_halves=self.w_nearest_exp_halves,
+                frac_halves=self.w_nearest_frac_halves,
+                seed_exp=self.w_seed_exp,
+                seed_frac=self.w_seed_frac,
+                zero_out_threshold=self.w_zero_out_t,
+            )
+            self.w_seed_exp = w_seed_exp
+            self.w_seed_frac = w_seed_frac
+
+        # 4. Linear transformation
+        return torch.nn.functional.linear(x, w, self.bias)
+
+    @classmethod
+    def from_linear(
+        cls,
+        linear: torch.nn.Linear,
+        x_p_exp: float | None,
+        x_p_frac: float | None,
+        x_zero_out_t: float | None,
+        w_p_exp: float | None,
+        w_p_frac: float | None,
+        w_zero_out_t: float | None,
+        x_seed_exp: int = 0,
+        x_seed_frac: int = 0,
+        w_seed_exp: int = 0,
+        w_seed_frac: int = 0,
+        r: int = 32,
+        lora_alpha: int = 32,
+    ) -> "BitFlipLinearLora":
+        new_fc = cls(
+            linear.in_features,
+            linear.out_features,
+            linear.bias is not None,
+            linear.weight.device,
+            linear.weight.dtype,
+            x_p_exp=x_p_exp,
+            x_p_frac=x_p_frac,
+            x_zero_out_t=x_zero_out_t,
+            w_p_exp=w_p_exp,
+            w_p_frac=w_p_frac,
+            w_zero_out_t=w_zero_out_t,
+            x_seed_exp=x_seed_exp,
+            x_seed_frac=x_seed_frac,
+            w_seed_exp=w_seed_exp,
+            w_seed_frac=w_seed_frac,
+            r=r,
+            lora_alpha=lora_alpha,
+        )
+        with torch.no_grad():
+            if linear.weight.device != torch.device("meta"):
+                new_fc.weight.copy_(linear.weight)
+                if linear.bias is not None:
+                    new_fc.bias.copy_(linear.bias)
+        return new_fc
+
+    def merge_lora(self) -> None:
+        if self.r > 0:
+            with torch.no_grad():
+                self.weight += (self.lora_B @ self.lora_A) * self.scaling
+            self.lora_A = None
+            self.lora_B = None
+            self.r = 0
diff --git a/src/aixsim_models/bitflip/pretrain/__init__.py b/src/aixsim_models/bitflip/pretrain/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/aixsim_models/bitflip/arg_manager.py b/src/aixsim_models/bitflip/pretrain/arg_manager.py
similarity index 98%
rename from
src/aixsim_models/bitflip/arg_manager.py rename to src/aixsim_models/bitflip/pretrain/arg_manager.py index 3551f34..7059279 100644 --- a/src/aixsim_models/bitflip/arg_manager.py +++ b/src/aixsim_models/bitflip/pretrain/arg_manager.py @@ -2,7 +2,7 @@ from typing import Literal -from ..llm.arg_manager import ( +from ...llm.arg_manager import ( ArgJob, ArgProfiling, ArgMetrics, diff --git a/src/aixsim_models/bitflip/pretrainer.py b/src/aixsim_models/bitflip/pretrain/pretrainer.py similarity index 97% rename from src/aixsim_models/bitflip/pretrainer.py rename to src/aixsim_models/bitflip/pretrain/pretrainer.py index 09fd7b5..ae204d1 100644 --- a/src/aixsim_models/bitflip/pretrainer.py +++ b/src/aixsim_models/bitflip/pretrain/pretrainer.py @@ -22,10 +22,10 @@ ) from torchtitan.utils import device_module, device_type -from ..llm.tokenizer import build_tokenizer -from ..llm.pretrainer import train_loop, build_meta_model, count_params -from ..utils.torch_module import TransformConfigManager -from ..utils.wandb_utils import wandb_update_config, wandb_extract_and_update_tags +from ...llm.tokenizer import build_tokenizer +from ...llm.pretrainer import train_loop, build_meta_model, count_params +from ...utils.torch_module import TransformConfigManager +from ...utils.wandb_utils import wandb_update_config, wandb_extract_and_update_tags from .transform import transform_model, make_transform_histogram from .arg_manager import ArgRandomBitFlipTransform from .arg_manager import ( diff --git a/src/aixsim_models/bitflip/profiler.py b/src/aixsim_models/bitflip/pretrain/profiler.py similarity index 100% rename from src/aixsim_models/bitflip/profiler.py rename to src/aixsim_models/bitflip/pretrain/profiler.py diff --git a/src/aixsim_models/bitflip/pretrain/transform.py b/src/aixsim_models/bitflip/pretrain/transform.py new file mode 100644 index 0000000..b9a7f26 --- /dev/null +++ b/src/aixsim_models/bitflip/pretrain/transform.py @@ -0,0 +1,67 @@ +from typing import Literal, Optional 
+import logging
+
+import torch
+
+from ...utils.deps import all_packages_are_available
+from ...utils.torch_module import TransformConfigManager, set_layer_by_name
+
+logger = logging.getLogger(__name__)
+
+
+if not all_packages_are_available(("mase_triton",)):
+
+    def transform_model(*args, **kwargs):
+        raise ImportError("mase-triton is not installed. Please install mase-triton to use this feature.")
+
+    def make_transform_histogram(*args, **kwargs):
+        raise ImportError("mase-triton is not installed. Please install mase-triton to use this feature.")
+
+else:
+
+    from mase_triton.random_bitflip.layers import RandomBitFlipDropout, RandomBitFlipLinear
+
+    def flip_bits_in_linear(model: torch.nn.Module, config_manager: TransformConfigManager) -> list[tuple[str, str]]:
+        replaced_layers = []
+
+        for name, layer in model.named_modules():
+            if isinstance(layer, torch.nn.Linear):
+                layer_cfg = config_manager.get_layer_config(name)
+                if layer_cfg is None:
+                    continue
+                new_layer = RandomBitFlipLinear.from_linear(layer, **layer_cfg)
+                set_layer_by_name(model, name, new_layer)
+                replaced_layers.append((name, config_manager.get_layer_config_entry(name)))
+        return replaced_layers
+
+    def transform_model(
+        model: torch.nn.Module, config_manager: TransformConfigManager, transform_flavor: Literal["fc"]
+    ) -> list[tuple[str, str]]:
+        """
+        Transform a model into the random bitflip form using the given configuration manager and transform flavor.
+
+        Args:
+            model (torch.nn.Module): The model to transform.
+            config_manager (TransformConfigManager): The configuration manager for the transformation.
+            transform_flavor (Literal["fc"]): The flavor of the transformation to apply.
+ + Returns: + list[tuple[str, str]]: A list of tuples containing the names of the layers that were replaced and the configuration + entry that was used for the replacement + """ + if transform_flavor == "fc": + return flip_bits_in_linear(model, config_manager) + else: + raise ValueError(f"Unknown transform flavor {transform_flavor}") + + + def make_transform_histogram(replaced_layers: list[tuple[str, str]]) -> dict[str, dict[str, int | list[str]]]: + patterns = set(layer[1] for layer in replaced_layers) + histogram = {pattern: {"count": 0, "layers": []} for pattern in patterns} + for layer, pattern in replaced_layers: + histogram[pattern]["count"] += 1 + histogram[pattern]["layers"].append(layer) + histogram["total"] = {"layer count": len(replaced_layers), "pattern count": len(patterns)} + return histogram diff --git a/src/aixsim_models/bitflip/transform.py b/src/aixsim_models/bitflip/transform.py deleted file mode 100644 index b022eab..0000000 --- a/src/aixsim_models/bitflip/transform.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Literal, Optional -import logging - -import torch -from chop.passes.module.transforms.bitflip import bitflip_module_transform_pass -from ..utils.torch_module import TransformConfigManager -from ..utils.deps import all_packages_are_available - -logger = logging.getLogger(__name__) - - -if not all_packages_are_available(("mase_triton", "chop")): - - def transform_model(*args, **kwargs): - raise ImportError("mase-triton or chop not installed. Please install mase-triton to use this feature.") - - def make_transform_histogram(*args, **kwargs): - raise ImportError("mase-triton or chop not installed. Please install mase-triton to use this feature.") - -else: - - def transform_model( - model: torch.nn.Module, config_manager: TransformConfigManager, transform_flavor: Optional[Literal["fc"]] = None - ) -> None: - """ - Transform a model into the random bitflip form using the given configuration manager and transform flavor. 
- - Args: - model (torch.nn.Module): The model to transform. - config_manager (TransformConfigManager): The configuration manager for the transformation. - transform_flavor (Optional[Literal["fc"]]): The flavor of the transformation. Defaults to None. - - Returns: - torch.nn.Module: The transformed model. - """ - - if transform_flavor is None or transform_flavor == "fc": - # *: use the bitflip transform pass in mase-tools - pass_args = config_manager.layer_name_to_config - pass_args = pass_args | {"by": "regex_name" if config_manager.use_regex else "name"} - bitflip_module_transform_pass(model, pass_args=pass_args) - else: - raise ValueError(f"Unknown transform flavor {transform_flavor}") - - def make_transform_histogram(replaced_layers: list[tuple[str, str]]) -> dict[str, dict[str, int | list[str]]]: - raise NotImplementedError("make_transform_histogram is not implemented.") - # patterns = set(layer[1] for layer in replaced_layers) - # histogram = {pattern: {"count": 0, "layers": []} for pattern in patterns} - # for layer, pattern in replaced_layers: - # histogram[pattern]["count"] += 1 - # histogram[pattern]["layers"].append(layer) - # histogram["total"] = {"layer count": len(replaced_layers), "pattern count": len(patterns)} - # return histogram diff --git a/src/aixsim_models/optical_compute/optical_transformer/pretrain/arg_manager.py b/src/aixsim_models/optical_compute/optical_transformer/pretrain/arg_manager.py index afb9bfd..4f6c411 100644 --- a/src/aixsim_models/optical_compute/optical_transformer/pretrain/arg_manager.py +++ b/src/aixsim_models/optical_compute/optical_transformer/pretrain/arg_manager.py @@ -2,7 +2,7 @@ from typing import Literal -from ...llm.arg_manager import ( +from ....llm.arg_manager import ( ArgJob, ArgProfiling, ArgMetrics, @@ -54,8 +54,8 @@ class PreTrainArgs: Communications library settings. memory_estimation : ArgMemoryEstimation Memory estimation settings. - transform: ArgRandomBitFlipTransform - Random bitflip transformation. 
+ transform: ArgOpticalTransformerTransform + Optical transformer transformation. """ job: ArgJob = field(default_factory=ArgJob) diff --git a/src/aixsim_models/optical_compute/optical_transformer/pretrain/pretrainer.py b/src/aixsim_models/optical_compute/optical_transformer/pretrain/pretrainer.py index 46091b1..6a1ebae 100644 --- a/src/aixsim_models/optical_compute/optical_transformer/pretrain/pretrainer.py +++ b/src/aixsim_models/optical_compute/optical_transformer/pretrain/pretrainer.py @@ -25,10 +25,10 @@ ) from torchtitan.utils import device_module, device_type -from ...llm.tokenizer import build_tokenizer -from ...llm.pretrainer import train_loop, build_meta_model, count_params -from ...utils.torch_module import TransformConfigManager -from ...utils.wandb_utils import wandb_update_config, wandb_extract_and_update_tags +from ....llm.tokenizer import build_tokenizer +from ....llm.pretrainer import train_loop, build_meta_model, count_params +from ....utils.torch_module import TransformConfigManager +from ....utils.wandb_utils import wandb_update_config, wandb_extract_and_update_tags from .transform import transform_torchtitan_model, make_transform_histogram from .arg_manager import ( ArgJob, diff --git a/src/aixsim_models/optical_compute/optical_transformer/pretrain/transform.py b/src/aixsim_models/optical_compute/optical_transformer/pretrain/transform.py index 2abcfa1..c548ea5 100644 --- a/src/aixsim_models/optical_compute/optical_transformer/pretrain/transform.py +++ b/src/aixsim_models/optical_compute/optical_transformer/pretrain/transform.py @@ -5,9 +5,9 @@ from torchtitan.models.llama.model import Transformer as TTLlamaTransformer from transformers import LlamaForCausalLM as HFLlamaForCausalLM from mase_triton.optical_compute.layers import OpticalTransformerLinear -from ...utils.torch_module import set_layer_by_name, get_layer_name -from ...utils.deps import all_packages_are_available -from ...utils.torch_module import TransformConfigManager +from 
....utils.torch_module import set_layer_by_name, get_layer_name +from ....utils.deps import all_packages_are_available +from ....utils.torch_module import TransformConfigManager from .layers import TTOpticalTransformerLlamaAttention, HFOpticalTransformerLlamaAttention if not all_packages_are_available(("mase_triton",)):
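For intuition about the `merge_lora` method introduced in `bitflip_lora.py`: it folds the low-rank product into the base weight so that inference needs a single matmul. The underlying identity, `(W + scaling * B @ A) @ x == W @ x + scaling * (B @ (A @ x))`, can be checked with a tiny plain-Python sketch (the 2x2 shapes and values here are hypothetical, not from the repository, and no torch is needed):

```python
# Sanity sketch of the LoRA merge identity used by BitFlipLinearLora.merge_lora.
# Matrices are plain lists of rows; shapes are hypothetical.

def matmul(m, n):
    # (p x q) @ (q x s) -> (p x s)
    return [[sum(m[i][k] * n[k][j] for k in range(len(n))) for j in range(len(n[0]))] for i in range(len(m))]

def matvec(m, v):
    # (p x q) @ (q,) -> (p,)
    return [sum(m[i][k] * v[k] for k in range(len(v))) for i in range(len(m))]

W = [[1.0, 2.0], [3.0, 4.0]]  # base weight (out_features=2, in_features=2)
B = [[0.5], [1.0]]            # lora_B (out_features=2, r=1)
A = [[2.0, -1.0]]             # lora_A (r=1, in_features=2)
r, lora_alpha = 1, 32
scaling = lora_alpha / r      # mirrors self.scaling = lora_alpha / r

x = [1.0, 1.0]

# Unmerged forward: base path plus scaled low-rank path.
low_rank = matvec(B, matvec(A, x))
y_unmerged = [w + scaling * lr for w, lr in zip(matvec(W, x), low_rank)]

# Merged forward: fold scaling * B @ A into W once, then a single matvec.
BA = matmul(B, A)
W_merged = [[W[i][j] + scaling * BA[i][j] for j in range(2)] for i in range(2)]
y_merged = matvec(W_merged, x)

assert all(abs(a - b) < 1e-9 for a, b in zip(y_unmerged, y_merged))
```

Note that in the diff's `forward`, the weight bitflip is applied *after* the LoRA delta is added (`w = weight + B @ A * scaling`), so bit errors perturb the merged weight exactly as they would at inference time after `merge_lora`.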