
Commit 5f7dc4e

Add a lora perf test (#1272)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
1 parent 06bff73 · commit 5f7dc4e

File tree

.buildkite/features/LoRA_Torch.yml
tests/lora/test_lora_perf.py

2 files changed: +55 / -1 lines

.buildkite/features/LoRA_Torch.yml

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,8 @@ steps:
       queue: tpu_v6e_queue
     commands:
     - |
-      buildkite-agent meta-data set "LoRA_Torch_PerformanceTest" "to be added"
+      .buildkite/scripts/run_in_docker.sh \
+        bash -c 'MODEL_IMPL_TYPE=vllm TPU_BACKEND_TYPE=jax python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_lora_perf.py'
 - label: "Record performance test result for LoRA_Torch"
   key: "record_LoRA_Torch_PerformanceTest"
   depends_on: "LoRA_Torch_PerformanceTest"
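
The same invocation can be reproduced locally through pytest's Python entry point instead of the Buildkite step. A minimal sketch, assuming a local tpu_inference checkout rather than the container's /workspace mount used by run_in_docker.sh; the env-var names and pytest flags are taken from the command above:

import os
import pytest

# Mirror the environment the CI command sets before invoking pytest.
os.environ["MODEL_IMPL_TYPE"] = "vllm"
os.environ["TPU_BACKEND_TYPE"] = "jax"

# Same flags as the Buildkite step; the path assumes a local checkout
# instead of the /workspace/tpu_inference mount inside Docker.
pytest.main(["-s", "-v", "-x", "tests/lora/test_lora_perf.py"])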

tests/lora/test_lora_perf.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import os
import time

import pytest
import vllm
from vllm.lora.request import LoRARequest

TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]


@pytest.mark.parametrize("tp", TP)
def test_lora_performance(tp):
    prompt = "What is 1+1? \n"
    llm_without_lora = vllm.LLM(
        model="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=256,
        max_num_batched_tokens=64,
        max_num_seqs=8,
        tensor_parallel_size=tp,
    )
    start_time = time.time()
    llm_without_lora.generate(
        prompt,
        sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
    )[0].outputs[0].text
    base_time = time.time() - start_time

    del llm_without_lora
    # Waiting for TPUs to be released
    time.sleep(10)

    llm_with_lora = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                             max_model_len=256,
                             max_num_batched_tokens=64,
                             max_num_seqs=8,
                             tensor_parallel_size=tp,
                             enable_lora=True,
                             max_loras=1,
                             max_lora_rank=8)
    lora_request = LoRARequest(
        "lora_adapter_2", 2,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
    start_time = time.time()
    llm_with_lora.generate(prompt,
                           sampling_params=vllm.SamplingParams(max_tokens=16,
                                                               temperature=0),
                           lora_request=lora_request)[0].outputs[0].text
    lora_time = time.time() - start_time
    print(f"Base time: {base_time}, LoRA time: {lora_time}")
    assert (base_time /
            lora_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"

    del llm_with_lora
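
One detail worth noting when reading the TP selection at the top of the file: os.environ.get returns the raw string whenever the variable is set, so any non-empty value (even "0" or "false") is truthy and selects tensor_parallel_size=2. A sketch of a stricter variant, purely illustrative and not part of this commit:

import os

# Illustrative only: parse the flag explicitly so values like "0" or
# "false" do not accidentally select the v6e-8 queue configuration.
def use_v6e8_queue() -> bool:
    return os.environ.get("USE_V6E8_QUEUE", "").strip().lower() in ("1", "true", "yes")

TP = [2] if use_v6e8_queue() else [1]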
