
Commit 4db50f5

[Misc] Update Attention backend registry (#1215)
1 parent 6c33a64 commit 4db50f5

File tree

1 file changed: +9 -9 lines changed


tpu_inference/platforms/tpu_platform.py

Lines changed: 9 additions & 9 deletions
@@ -16,15 +16,15 @@
 from tpu_inference.utils import to_jax_dtype, to_torch_dtype

 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import _Backend
+    from vllm.attention.backends.registry import AttentionBackendEnum
     from vllm.config import BlockSize, ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
 else:
     BlockSize = None
     ModelConfig = None
     VllmConfig = None
     PoolingParams = None
-    _Backend = None
+    AttentionBackendEnum = None

 logger = init_logger(__name__)

@@ -48,13 +48,13 @@ class TpuPlatform(Platform):
     ]

     @classmethod
-    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
-                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
-                             block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool,
-                             attn_type: Any) -> str:
-        from vllm.attention.backends.registry import _Backend
-        if selected_backend != _Backend.PALLAS:
+    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
+                             head_size: int, dtype: jnp.dtype,
+                             kv_cache_dtype: Optional[str], block_size: int,
+                             use_v1: bool, use_mla: bool, has_sink: bool,
+                             use_sparse: bool, attn_type: Any) -> str:
+        from vllm.attention.backends.registry import AttentionBackendEnum
+        if selected_backend != AttentionBackendEnum.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

         if use_v1:
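
For context beyond the raw diff, below is a minimal, self-contained sketch of the selection pattern the updated method follows. It uses stand-in definitions rather than the real vLLM registry: the extra enum member and the returned class path are placeholders and are not part of this commit.

# Illustrative sketch only: mirrors the backend-selection pattern shown in the
# diff above using stand-ins; it does not import vLLM or tpu_inference.
import enum
import logging
from typing import Any, Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AttentionBackendEnum(enum.Enum):
    # Stand-in for vllm.attention.backends.registry.AttentionBackendEnum;
    # only PALLAS matters for the TPU path shown in the diff.
    PALLAS = enum.auto()
    FLASH_ATTN = enum.auto()  # hypothetical extra member for the example


def get_attn_backend_cls(selected_backend: AttentionBackendEnum,
                         head_size: int, dtype: Any,
                         kv_cache_dtype: Optional[str], block_size: int,
                         use_v1: bool, use_mla: bool, has_sink: bool,
                         use_sparse: bool, attn_type: Any) -> str:
    # As in the diff, a non-PALLAS backend is rejected on TPU with a log
    # message; the v1 branch then returns a backend class path (a placeholder
    # string here, not the real tpu_inference class).
    if selected_backend != AttentionBackendEnum.PALLAS:
        logger.info("Cannot use %s backend on TPU.", selected_backend)
    if use_v1:
        return "tpu_inference.attention.PallasAttentionBackend"  # placeholder
    raise ValueError("only the v1 engine path is sketched here")


if __name__ == "__main__":
    print(get_attn_backend_cls(AttentionBackendEnum.PALLAS, head_size=128,
                               dtype=None, kv_cache_dtype=None, block_size=16,
                               use_v1=True, use_mla=False, has_sink=False,
                               use_sparse=False, attn_type=None))

Running the sketch prints the placeholder class path for the PALLAS case; any other enum member would only trigger the "Cannot use %s backend on TPU." log message before the v1 check.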
