 from tpu_inference.utils import to_jax_dtype, to_torch_dtype

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import _Backend
+    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-    _Backend = None
+    AttentionBackendEnum = None

 logger = init_logger(__name__)

@@ -48,13 +48,13 @@ class TpuPlatform(Platform):
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
-                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool,
-                             attn_type: Any) -> str:
-        from vllm.attention.backends.registry import _Backend
-        if selected_backend != _Backend.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             head_size: int, dtype: jnp.dtype,
+                             kv_cache_dtype: Optional[str], block_size: int,
+                             use_v1: bool, use_mla: bool, has_sink: bool,
+                             use_sparse: bool, attn_type: Any) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

         if use_v1:
0 commit comments