diff --git a/tpu_inference/utils.py b/tpu_inference/utils.py index 15cd1fe98..b147bbeea 100644 --- a/tpu_inference/utils.py +++ b/tpu_inference/utils.py @@ -28,9 +28,9 @@ # Map vllm dtype string that doesn't exactly match jax dtype string name. _VLLM_DTYPE_STR_TO_JAX_DTYPE = { - "fp8": jnp.float8_e4m3fn, - "fp8_e4m3": jnp.float8_e4m3fn, - "fp8_e5m2": jnp.float8_e5m2, + "fp8": jnp.float8_e4m3fn.dtype, + "fp8_e4m3": jnp.float8_e4m3fn.dtype, + "fp8_e5m2": jnp.float8_e5m2.dtype, }