# Intel® Extension for PyTorch\* Large Language Model (LLM) Feature Get Started For Llama 3.1 models

Intel® Extension for PyTorch\* provides dedicated optimizations for running Llama 3.1 models faster, covering technical points such as paged attention and ROPE fusion. A set of data types is supported for various scenarios, including BF16 and Weight-Only Quantization INT8/INT4 (prototype).

# 1. Environment Setup

Several environment setup methods are provided. Choose the one that fits your usage scenario; the Docker-based setup is recommended.

## 1.1 [RECOMMENDED] Docker-based environment setup with pre-built wheels

```bash
# Get the Intel® Extension for PyTorch\* source code
git clone https://github.com/intel/intel-extension-for-pytorch.git
cd intel-extension-for-pytorch
git checkout 2.4-llama-3
git submodule sync
git submodule update --init --recursive

# Build an image with the provided Dockerfile by installing from Intel® Extension for PyTorch\* prebuilt wheel files
DOCKER_BUILDKIT=1 docker build -f examples/cpu/inference/python/llm/Dockerfile -t ipex-llm:2.4.0 .

# Run the container with the command below
docker run --rm -it --privileged ipex-llm:2.4.0 bash

# When the command prompt shows inside the docker container, enter the llm examples directory
cd llm

# Activate environment variables
source ./tools/env_activate.sh
```

## 1.2 Conda-based environment setup with pre-built wheels

```bash
# Get the Intel® Extension for PyTorch\* source code
git clone https://github.com/intel/intel-extension-for-pytorch.git
cd intel-extension-for-pytorch
git checkout 2.4-llama-3
git submodule sync
git submodule update --init --recursive

# Create a conda environment (pre-built wheels are only available with python=3.10)
conda create -n llm python=3.10 -y
conda activate llm

# Set up the environment with the provided script
# A sample "prompt.json" file for benchmarking is also downloaded
cd examples/cpu/inference/python/llm
bash ./tools/env_setup.sh 7

# Activate environment variables
source ./tools/env_activate.sh
```
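
Whichever setup path you chose, a quick sanity check along the lines of the following (a minimal sketch; the printed version strings depend on the wheels the setup scripts installed) confirms that PyTorch\* and the extension import correctly before moving on:

```bash
# Confirm that PyTorch* and Intel® Extension for PyTorch* can be imported from the activated environment
python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__, ipex.__version__)"
```
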
<br>

# 2. How To Run Llama 3.1 with ipex.llm

**ipex.llm provides a single script to facilitate running generation tasks as below:**

```bash
# if you are using a docker container built from commands above in Sec. 1.1, the placeholder LLM_DIR below is ~/llm
# if you are using a conda env created with commands above in Sec. 1.2, the placeholder LLM_DIR below is intel-extension-for-pytorch/examples/cpu/inference/python/llm
cd <LLM_DIR>
python run.py --help # for more detailed usages
```

| Key args of run.py | Notes |
|---|---|
| model id | "--model-name-or-path" or "-m" to specify the <LLAMA3_MODEL_ID_OR_LOCAL_PATH>; it can be a model id from Hugging Face or a downloaded local path |
| generation | default: beam search (beam size = 4), "--greedy" for greedy search |
| input tokens | provide a fixed input prompt size; use "--input-tokens" with <INPUT_LENGTH> in [1024, 2048, 4096, 8192, 32768, 130944]; if "--input-tokens" is not used, use "--prompt" to provide other strings as inputs |
| output tokens | default: 32, use "--max-new-tokens" to choose any other size |
| batch size | default: 1, use "--batch-size" to choose any other size |
| token latency | enable "--token-latency" to print out the first-token and next-token latencies |
| generation iterations | use "--num-iter" and "--num-warmup" to control the repeated iterations of generation, default: 100-iter/10-warmup |
| streaming mode output | greedy search only (works with "--greedy"); use "--streaming" to enable streaming generation output |

*Note:* You may need to log in to your Hugging Face account to access the model files. Please refer to [HuggingFace login](https://huggingface.co/docs/huggingface_hub/quick-start#login).
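
As an illustration of how these arguments compose (the values below are placeholders, not recommendations; run `python run.py --help` for the authoritative list), a greedy, streaming BF16 run with a fixed 1024-token prompt might look like:

```bash
# Illustrative sketch only: greedy, streaming generation with a 1024-token prompt,
# 128 new tokens, batch size 1, and per-token latency printed
python run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --dtype bfloat16 --ipex \
  --greedy --streaming --input-tokens 1024 --max-new-tokens 128 --batch-size 1 --token-latency
```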

## 2.1 Usage of running Llama 3.1 models

The _\<LLAMA3_MODEL_ID_OR_LOCAL_PATH\>_ in the commands below specifies the Llama 3.1 model you will run, which can be found on [HuggingFace Models](https://huggingface.co/models).

### 2.1.1 Run generation with multiple instances on multiple CPU numa nodes

#### 2.1.1.1 Prepare:

```bash
unset KMP_AFFINITY
```

In the DeepSpeed cases below, we recommend using "--shard-model" to shard the model weights more evenly for better memory usage.

If "--shard-model" is used, a copy of the sharded model weight files is saved to the "--output-dir" path (default: "./saved_results" if not provided).
If you have already used "--shard-model" and generated such a sharded model path (or your model weight files are already well sharded), in further repeated benchmarks please remove "--shard-model" and replace "-m <LLAMA3_MODEL_ID_OR_LOCAL_PATH>" with "-m <shard model path>" to skip the repeated sharding step.

In addition, a standalone model-sharding script is provided in Section 2.1.1.4, in case you would like to generate the sharded model weight files in advance of running distributed inference.

#### 2.1.1.2 BF16:

- Command:
```bash
deepspeed --bind_cores_to_rank run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --dtype bfloat16 --ipex --greedy --input-tokens <INPUT_LENGTH> --autotp --shard-model
```
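
If a previous run already used "--shard-model" (so the sharded weights were written under "--output-dir", "./saved_results" by default), a repeated benchmark can skip re-sharding by pointing "-m" at that path, as described in 2.1.1.1. A sketch of such a rerun (the sharded model path placeholder stands for wherever the shards were actually written):

```bash
# Repeated benchmark reusing previously sharded weights: "--shard-model" is dropped and
# "-m" points at the sharded model path written by the earlier run
deepspeed --bind_cores_to_rank run.py --benchmark -m <shard model path> --dtype bfloat16 --ipex --greedy --input-tokens <INPUT_LENGTH> --autotp
```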

#### 2.1.1.3 Weight-only quantization (INT8):

By default, for weight-only quantization, we use quantization with [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html) inference ("--quant-with-amp") to get peak performance and fair accuracy.
For weight-only quantization with DeepSpeed, the model is quantized and then the benchmark is run; the quantized model is not saved.

- Command:
```bash
deepspeed --bind_cores_to_rank run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --ipex --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --greedy --input-tokens <INPUT_LENGTH> --autotp --shard-model --output-dir "saved_results"
# Note: you can add "--group-size" to tune for better accuracy; suggested values are one of [32, 64, 128, 256, 512].
```

#### 2.1.1.4 How to Shard Model Weight Files for Distributed Inference with DeepSpeed

To reduce memory usage, you can shard the model weight files to a local path before launching distributed tests with DeepSpeed.

```bash
cd ./utils
# general command:
python create_shard_model.py -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --save-path ./local_llama3_model_shard
# After sharding the model, use "-m ./local_llama3_model_shard" in later tests
```

### 2.1.2 Run generation with single-socket inference
#### 2.1.2.1 BF16:

- Command:

```bash
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list> python run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --dtype bfloat16 --ipex --greedy --input-tokens <INPUT_LENGTH>
```
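
As a concrete illustration, on a hypothetical machine whose NUMA node 0 has 56 physical cores numbered 0-55 (see the notes in 2.1.2.4 for how to check your own topology), the placeholders would be filled in as:

```bash
# Assumes NUMA node 0 with physical cores 0-55 and a fixed 1024-token prompt; adjust to your hardware
OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --dtype bfloat16 --ipex --greedy --input-tokens 1024
```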

#### 2.1.2.2 Weight-only quantization (INT8):

By default, for weight-only quantization, we use quantization with [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html) inference ("--quant-with-amp") to get peak performance and fair accuracy.

- Command:

```bash
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list> python run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --output-dir "saved_results" --greedy --input-tokens <INPUT_LENGTH>
# Note: you can add "--group-size" to tune for better accuracy; suggested values are one of [32, 64, 128, 256, 512].
```
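
For instance, to experiment with an explicit quantization group size (the value 128 here is just one of the suggested options from the note above, not a tuned recommendation), the command can be extended as:

```bash
# Same WOQ INT8 benchmark with an explicit group size of 128
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list> python run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --group-size 128 --output-dir "saved_results" --greedy --input-tokens <INPUT_LENGTH>
```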

#### 2.1.2.3 Weight-only quantization (INT4):
You can use auto-round (part of Intel® Neural Compressor, INC) to generate an INT4 WOQ model with the following steps.
- Environment installation:
```bash
pip install git+https://github.com/intel/auto-round.git@e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
git clone https://github.com/intel/auto-round.git
cd auto-round
git checkout e24b9074af6cdb099e31c92eb81b7f5e9a4a244e
cd examples/language-modeling
```

- Command (quantize):
```bash
python3 main.py --model_name $model_name --device cpu --sym --nsamples 512 --iters 1000 --group_size 32 --deployment_device cpu --disable_eval --output_dir <INT4_MODEL_SAVE_PATH>
```
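
For example, with the 8B instruct checkpoint as an assumed target (substitute whichever Llama 3.1 model id or local path you actually intend to quantize, and your own output directory):

```bash
# Illustrative only: $model_name and the output directory are assumptions, not fixed choices
model_name=meta-llama/Meta-Llama-3.1-8B-Instruct
python3 main.py --model_name $model_name --device cpu --sym --nsamples 512 --iters 1000 --group_size 32 --deployment_device cpu --disable_eval --output_dir ./llama31_int4_model
```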

- Command (benchmark):
```bash
cd <LLM_DIR>
IPEX_WOQ_GEMM_LOOP_SCHEME=ACB OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list> python run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --ipex-weight-only-quantization --weight-dtype INT4 --quant-with-amp --output-dir "saved_results" --greedy --input-tokens <INPUT_LENGTH> --cache-weight-for-large-batch --low-precision-checkpoint <INT4_MODEL_SAVE_PATH>
```

#### 2.1.2.4 Notes:

(1) [_numactl_](https://linux.die.net/man/8/numactl) is used to specify the memory and cores of your hardware for better performance. _\<node N\>_ specifies the [numa](https://en.wikipedia.org/wiki/Non-uniform_memory_access) node id (e.g., 0 to use the memory from the first numa node). _\<physical cores list\>_ specifies the physical cores you are using from the _\<node N\>_ numa node. You can use the [_lscpu_](https://man7.org/linux/man-pages/man1/lscpu.1.html) command in Linux to check the numa node information.
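
For instance, a quick way to find the values for _\<node N\>_ and _\<physical cores list\>_ on your machine (output varies by hardware) is:

```bash
# Show the NUMA node count and which CPU ids belong to each node
lscpu | grep -i "numa"
# Optionally list CPU id / NUMA node / physical core mappings to distinguish physical cores from SMT siblings
lscpu -e=CPU,NODE,CORE | head -20
```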

(2) For all quantization benchmarks, both the quantization and inference stages are triggered by default. In the quantization stage, a quantized model named "best_model.pt" is auto-generated in the "--output-dir" path; the inference stage then launches inference with that quantized model. For inference-only benchmarks (to avoid repeating the quantization stage), you can reuse these quantized models by adding "--quantized-model-path <output_dir + "best_model.pt">".
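
A sketch of such an inference-only rerun, assuming the quantized model from an earlier single-socket WOQ INT8 run was written to the default "saved_results" directory:

```bash
# Reuse the previously generated "best_model.pt" so only the inference stage runs
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list> python run.py --benchmark -m <LLAMA3_MODEL_ID_OR_LOCAL_PATH> --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --greedy --input-tokens <INPUT_LENGTH> --quantized-model-path "./saved_results/best_model.pt"
```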

## Miscellaneous Tips
Intel® Extension for PyTorch\* also provides dedicated optimizations for many other Large Language Models (LLMs), covering a set of data types supported for various scenarios. For more details, please check this [Intel® Extension for PyTorch\* doc](https://github.com/intel/intel-extension-for-pytorch/blob/release/2.3/README.md).