Skip to content

Commit 2625387

Browse files
committed
Fix GKE KV cache verification by setting sampling_params.temperature=0
Signed-off-by: Juncheng Gu <jcgu@google.com>
1 parent e8bc4ac commit 2625387

File tree

2 files changed

+7
-24
lines changed

2 files changed

+7
-24
lines changed

examples/offload/offline_inference_kv_cache_verification.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,6 @@ def create_parser():
3838
parser.set_defaults(model="meta-llama/Llama-3.1-8B")
3939
parser.set_defaults(max_model_len=1024)
4040

41-
# Add sampling params
42-
sampling_group = parser.add_argument_group("Sampling parameters")
43-
sampling_group.add_argument("--max-tokens", type=int)
44-
sampling_group.add_argument("--temperature", type=float)
45-
sampling_group.add_argument("--top-p", type=float)
46-
sampling_group.add_argument("--top-k", type=int)
4741
return parser
4842

4943

@@ -52,25 +46,14 @@ def setup_llm(llm_args: dict) -> Tuple[LLM, SamplingParams]:
5246
Initializes a vLLM engine and sampling parameters from the given args.
5347
"""
5448
args_copy = copy.deepcopy(llm_args)
55-
# Pop arguments not used by LLM
56-
max_tokens = args_copy.pop("max_tokens")
57-
temperature = args_copy.pop("temperature")
58-
top_p = args_copy.pop("top_p")
59-
top_k = args_copy.pop("top_k")
60-
6149
# Create an LLM. The --seed argument is passed in via **args.
6250
llm = LLM(**args_copy)
6351

64-
# Create a sampling params object
65-
sampling_params = llm.get_default_sampling_params()
66-
if max_tokens is not None:
67-
sampling_params.max_tokens = max_tokens
68-
if temperature is not None:
69-
sampling_params.temperature = temperature
70-
if top_p is not None:
71-
sampling_params.top_p = top_p
72-
if top_k is not None:
73-
sampling_params.top_k = top_k
52+
# Create the sampling params object
53+
sampling_params = SamplingParams(temperature=0,
54+
max_tokens=20,
55+
seed=42,
56+
ignore_eos=True)
7457

7558
return llm, sampling_params
7659

tests/distributed/offload/tpu_offload_connector_worker_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
logger = init_logger(__name__)
2828

29-
_DEFAULT_BLOCK_SIZE = 256
29+
_DEFAULT_BLOCK_SIZE = 64
3030

3131

3232
class MockTPUModelRunner(TPUModelRunner):
@@ -97,7 +97,7 @@ def tearDown(self):
9797
super().tearDown()
9898
# Destroy references explicitly
9999
if hasattr(self, 'connector'):
100-
del self.connector
100+
del self.connector
101101

102102
# Force JAX to release memory
103103
cc.reset_cache()

0 commit comments

Comments
 (0)