We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 05e4b16 commit 0333450 (Copy full SHA for 0333450)
tpu_inference/utils.py
@@ -28,9 +28,9 @@
28
29
# Map vllm dtype string that doesn't exactly match jax dtype string name.
30
_VLLM_DTYPE_STR_TO_JAX_DTYPE = {
31
- "fp8": jnp.float8_e4m3fn,
32
- "fp8_e4m3": jnp.float8_e4m3fn,
33
- "fp8_e5m2": jnp.float8_e5m2,
+ "fp8": jnp.float8_e4m3fn.dtype,
+ "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+ "fp8_e5m2": jnp.float8_e5m2.dtype,
34
}
35
36
0 commit comments