Skip to content

Commit ab6985e

Browse files
wip
1 parent 05e4b16 commit ab6985e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tpu_inference/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828

2929
# Map vllm dtype string that doesn't exactly match jax dtype string name.
3030
_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
31-
"fp8": jnp.float8_e4m3fn,
32-
"fp8_e4m3": jnp.float8_e4m3fn,
33-
"fp8_e5m2": jnp.float8_e5m2,
31+
"fp8": jnp.float8_e4m3fn.dtype,
32+
"fp8_e4m3": jnp.float8_e4m3fn.dtype,
33+
"fp8_e5m2": jnp.float8_e5m2.dtype,
3434
}
3535

3636

0 commit comments

Comments
 (0)