@@ -38,6 +38,11 @@ def create_parser():
     parser.set_defaults(model="meta-llama/Llama-3.1-8B")
     parser.set_defaults(max_model_len=1024)
 
+    # Add sampling params
+    sampling_group = parser.add_argument_group("Sampling parameters")
+    sampling_group.add_argument("--max-tokens", type=int)
+    sampling_group.add_argument("--top-p", type=float)
+    sampling_group.add_argument("--top-k", type=int)
     return parser
 
 
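As a quick sanity check, here is a minimal sketch of how the new flags surface in the parsed namespace (the argv values are made up for illustration; unset flags come back as `None`, which is what `setup_llm` below tests for):

```python
# Hypothetical invocation of the parser built above.
parser = create_parser()
args = parser.parse_args(["--max-tokens", "32", "--top-p", "0.95"])
assert args.max_tokens == 32 and args.top_p == 0.95
assert args.top_k is None  # unset flags default to None
```
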
@@ -46,14 +51,24 @@ def setup_llm(llm_args: dict) -> Tuple[LLM, SamplingParams]:
     Initializes a vLLM engine and sampling parameters from the given args.
     """
     args_copy = copy.deepcopy(llm_args)
+    # Pop arguments not used by LLM
+    max_tokens = args_copy.pop("max_tokens")
+    top_p = args_copy.pop("top_p")
+    top_k = args_copy.pop("top_k")
+
     # Create an LLM. The --seed argument is passed in via **args.
     llm = LLM(**args_copy)
 
-    # Create a sampling params
-    sampling_params = SamplingParams(temperature=0,
-                                     max_tokens=20,
-                                     seed=42,
-                                     ignore_eos=True)
+    # Create a sampling params object
+    sampling_params = llm.get_default_sampling_params()
+    sampling_params.temperature = 0
+    sampling_params.ignore_eos = True
+    if max_tokens is not None:
+        sampling_params.max_tokens = max_tokens
+    if top_p is not None:
+        sampling_params.top_p = top_p
+    if top_k is not None:
+        sampling_params.top_k = top_k
 
     return llm, sampling_params
 
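For context, a minimal sketch of how the two helpers might be wired together after this change. The driver function `main` and the prompt are hypothetical; `llm.generate` and the `RequestOutput` fields follow vLLM's standard offline-inference API:

```python
# Hypothetical driver, assuming llm_args is the parsed namespace as a dict.
def main():
    args = create_parser().parse_args()
    llm, sampling_params = setup_llm(vars(args))
    # Generate with the CLI-controlled sampling parameters.
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)
```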