Commit 5dfe210

Change sampling params to be configurable
Signed-off-by: dannawang <dannawang@google.com>
1 parent 2625387 commit 5dfe210

File tree

1 file changed: +20 −5 lines changed

examples/offload/offline_inference_kv_cache_verification.py

Lines changed: 20 additions & 5 deletions
@@ -38,6 +38,11 @@ def create_parser():
     parser.set_defaults(model="meta-llama/Llama-3.1-8B")
     parser.set_defaults(max_model_len=1024)
 
+    # Add sampling params
+    sampling_group = parser.add_argument_group("Sampling parameters")
+    sampling_group.add_argument("--max-tokens", type=int)
+    sampling_group.add_argument("--top-p", type=float)
+    sampling_group.add_argument("--top-k", type=int)
     return parser
 
 
@@ -46,14 +51,24 @@ def setup_llm(llm_args: dict) -> Tuple[LLM, SamplingParams]:
     Initializes a vLLM engine and sampling parameters from the given args.
     """
     args_copy = copy.deepcopy(llm_args)
+    # Pop arguments not used by LLM
+    max_tokens = args_copy.pop("max_tokens")
+    top_p = args_copy.pop("top_p")
+    top_k = args_copy.pop("top_k")
+
     # Create an LLM. The --seed argument is passed in via **args.
     llm = LLM(**args_copy)
 
-    # Create a sampling params
-    sampling_params = SamplingParams(temperature=0,
-                                     max_tokens=20,
-                                     seed=42,
-                                     ignore_eos=True)
+    # Create a sampling params object
+    sampling_params = llm.get_default_sampling_params()
+    sampling_params.temperature = 0
+    sampling_params.ignore_eos = True
+    if max_tokens is not None:
+        sampling_params.max_tokens = max_tokens
+    if top_p is not None:
+        sampling_params.top_p = top_p
+    if top_k is not None:
+        sampling_params.top_k = top_k
 
     return llm, sampling_params

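To see the whole pattern in one place, here is a minimal, self-contained sketch of what the file does after this change: the sampling flags sit in their own argparse group, get popped out of the arg dict before the rest is forwarded to LLM(**...), and only override the engine's default SamplingParams when the user actually set them. The vLLM calls (LLM, get_default_sampling_params) mirror the diff above; the __main__ driver, the prompt, and the generate() call are illustrative assumptions, not part of the commit.

import argparse
import copy

from vllm import LLM, SamplingParams


def create_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="meta-llama/Llama-3.1-8B")

    # Sampling flags get their own group; a default of None lets us tell
    # "flag not given" apart from an explicit value later on.
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument("--max-tokens", type=int)
    sampling_group.add_argument("--top-p", type=float)
    sampling_group.add_argument("--top-k", type=int)
    return parser


def setup_llm(llm_args: dict) -> tuple[LLM, SamplingParams]:
    args_copy = copy.deepcopy(llm_args)

    # Pop the sampling-only keys so LLM(**args_copy) never sees them.
    max_tokens = args_copy.pop("max_tokens")
    top_p = args_copy.pop("top_p")
    top_k = args_copy.pop("top_k")

    llm = LLM(**args_copy)

    # Start from the engine's default sampling params, then apply only the
    # overrides the user actually supplied on the command line.
    sampling_params = llm.get_default_sampling_params()
    sampling_params.temperature = 0
    sampling_params.ignore_eos = True
    if max_tokens is not None:
        sampling_params.max_tokens = max_tokens
    if top_p is not None:
        sampling_params.top_p = top_p
    if top_k is not None:
        sampling_params.top_k = top_k

    return llm, sampling_params


if __name__ == "__main__":
    # Illustrative driver, not part of the commit.
    args = vars(create_parser().parse_args())
    llm, sampling_params = setup_llm(args)
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    print(outputs[0].outputs[0].text)

A plausible invocation, given the flags above: python examples/offload/offline_inference_kv_cache_verification.py --max-tokens 32 --top-p 0.9 --top-k 40. Any flag left off simply keeps the model's default sampling value.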