From 3ce980475f1d68327981f2cc2b5ab21420108a25 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 22 Jan 2026 10:49:14 +0000 Subject: [PATCH 01/13] add-choose-vit-backend --- .../common/basemodel/attention/__init__.py | 1 + .../common/basemodel/attention/base_att.py | 20 ++++++ .../basemodel/attention/create_utils.py | 33 ++++++++- lightllm/common/basemodel/attention/fa3/fp.py | 69 ++++++++++++++++++- .../common/basemodel/attention/triton/fp.py | 33 ++++++++- lightllm/common/basemodel/basemodel.py | 3 +- lightllm/models/qwen2_vl/qwen2_visual.py | 8 ++- lightllm/models/qwen3_vl/qwen3_visual.py | 6 +- .../visualserver/model_infer/model_rpc.py | 10 ++- 9 files changed, 175 insertions(+), 8 deletions(-) diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index 80df54549..d6fba6966 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -15,4 +15,5 @@ get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, + get_vit_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 859d97ca8..28fd5596e 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -37,6 +37,9 @@ def create_att_prefill_state(self) -> "BasePrefillAttState": def create_att_decode_state(self) -> "BaseDecodeAttState": raise NotImplementedError("not impl") + def create_vit_att_state(self) -> "BaseVitAttState": + raise NotImplementedError("not impl") + def _find_layer_index( self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"] ) -> int: @@ -115,3 +118,20 @@ def decode_att( alloc_func=torch.empty, ) -> torch.Tensor: pass + + +class BaseVitAttState(ABC): + + backend: BaseAttBackend = None + + @abstractmethod + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + raise NotImplementedError("not impl") diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 19252cf13..40844c0e0 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -10,7 +10,7 @@ from .triton.int4kv import Int4kvTritonAttBackend from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend -from .fa3.fp import Fa3AttBackend +from .fa3.fp import Fa3AttBackend, Fa3ViTAttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend @@ -46,6 +46,13 @@ }, } +vit_data_type_to_backend = { + "None": { + "triton": TritonAttBackend, + "fa3": Fa3ViTAttBackend, + }, +} + def _auto_select_backend( llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] @@ -60,6 +67,7 @@ def _auto_select_backend( for backend_name in priority_list: if validate(backend_name): logger.info(f"Auto-selected {backend_name} backend (validated)") + print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") return backend_map[llm_dtype][backend_name] # Fallback to triton without validation (should not happen) @@ -67,6 +75,25 @@ def _auto_select_backend( return backend_map[llm_dtype]["triton"] +def _auto_select_vit_backend(llm_dtype: str, priority_list: list = 
["fa3", "triton"]) -> type: + """Auto-select the best available backend with validation for vit. + + Priority: FA3 > Triton + Each backend is validated in a subprocess with ground truth checks. + """ + backend_map = vit_data_type_to_backend + + for backend_name in priority_list: + if validate(backend_name): + logger.info(f"Auto-selected {backend_name} backend (validated) for ViT") + print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") + return backend_map[llm_dtype][backend_name] + + # Fallback to triton without validation (should not happen) + logger.warning("No backend validation succeeded for vit, falling back to triton") + return backend_map[llm_dtype]["triton"] + + def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type @@ -105,3 +132,7 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla return mla_data_type_to_backend[llm_dtype][backend_str] else: return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) + + +def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "triton"]) -> BaseAttBackend: + return _auto_select_vit_backend(llm_dtype="None", priority_list=priority_list) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d9..7369f7dee 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache @@ -37,6 +37,14 @@ def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": return Fa3DecodeAttState(backend=self, infer_state=infer_state) +class Fa3ViTAttBackend(BaseAttBackend): + def __init__(self, model): + super().__init__(model=model) + + def create_vit_att_state(self) -> "Fa3VitAttState": + return Fa3VitAttState(backend=self) + + @dataclasses.dataclass class Fa3PrefillAttState(BasePrefillAttState): cu_seqlens_q: torch.Tensor = None @@ -241,3 +249,62 @@ def _normal_decode_att( sinks=sink_weight, ) return o + + +@dataclasses.dataclass +class Fa3VitAttState(BaseVitAttState): + + backend: "Fa3ViTAttBackend" + + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> None: + self.backend: Fa3ViTAttBackend = self.backend # for typing + + head_dim = q.shape[-1] + softmax_scale = head_dim ** -0.5 + window_size = (-1, -1) + torch.ops.sgl_kernel.fwd.default( + q, + k, + v, + None, # k_new + None, # v_new + None, # qv + o, # out + cu_seqlens, + cu_seqlens, + None, # cu_seqlens_k_new + None, + None, + max_seqlen, + max_seqlen, + None, # page_table, + None, # kv_batch_idx + None, # leftpad_k + None, # rotary cos + None, # rotary sin + None, # seqlens_rotary + None, + None, + None, + softmax_scale, + False, + window_size[0], + window_size[1], + 0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, + ) + + return o diff --git a/lightllm/common/basemodel/attention/triton/fp.py 
b/lightllm/common/basemodel/attention/triton/fp.py index d29f15ec3..c35ad362a 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState from typing import Optional @@ -11,6 +11,9 @@ def create_att_prefill_state(self, infer_state) -> "TritonPrefillAttState": def create_att_decode_state(self, infer_state) -> "TritonDecodeAttState": return TritonDecodeAttState(backend=self, infer_state=infer_state) + def create_vit_att_state(self, infer_state) -> "TritonDecodeAttState": + return TritonVitAttState(backend=self, infer_state=infer_state) + @dataclasses.dataclass class TritonPrefillAttState(BasePrefillAttState): @@ -273,3 +276,31 @@ def _normal_decode_stage3_att( b_seq_len=self.infer_state.b_seq_len, ) return o_tensor + + +@dataclasses.dataclass +class TritonVitAttState(BaseVitAttState): + def init_state(self): + pass + + def _vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + alloc_func=torch.empty, + ): + from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd + + _flash_attention_triton_fwd( + q, + k, + v, + o, + cu_seqlens, # q k v cu_seqlens, + max_seqlen, + ) + return diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 26d51af3d..e97381884 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,7 +32,7 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache -from .attention import get_prefill_att_backend_class, get_decode_att_backend_class +from .attention import get_prefill_att_backend_class, get_decode_att_backend_class, get_vit_att_backend_class from .attention import BaseAttBackend logger = init_logger(__name__) @@ -119,7 +119,6 @@ def __init__(self, kvargs): self._init_custom() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() - self._init_att_backend() self._init_att_backend1() diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 334ffc844..7d53c326f 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -127,6 +127,7 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) + self.vit_att_backend = None def forward( self, @@ -143,7 +144,7 @@ def forward( attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + self.vit_att_backend.vit_att(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output @@ -234,6 +235,11 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return + def _init_vit_att(self, vit_att): + for blk in self.blocks: + blk.attn.vit_att_backend = vit_att + return + def load_model(self, weight_dir): processor_config_path = os.path.join(weight_dir, 
"preprocessor_config.json") diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index 00ad6c05a..dd119f2de 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -129,7 +129,6 @@ def __init__( ): super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") - self.depth = depth self.out_hidden_size = out_hidden_size self.hidden_size = hidden_size @@ -182,6 +181,11 @@ def __init__( ) self._init_datatype() + def _init_vit_att(self, vit_att): + for blk in self.blocks: + blk.attn.vit_att_backend = vit_att + return + def _init_datatype(self): if isinstance(self.data_type, torch.dtype): return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d3d1610f3..79434f5b9 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -24,9 +24,15 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel.basemodel import BaseAttBackend, get_vit_att_backend_class +from lightllm.utils.dist_utils import set_global_rank class VisualModelRpcServer(rpyc.Service): + def _init_vit_att_backend(self): + self.vit_att_backend: BaseAttBackend = get_vit_att_backend_class(index=0)(model=self) + return + def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) import torch @@ -42,7 +48,7 @@ def exposed_init_model(self, kvargs): self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] - + set_global_rank(kvargs["tp_rank_id"]) # 这里看后面怎么改 init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) @@ -81,6 +87,8 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") + self._init_vit_att_backend() + self.model._init_vit_att(self.vit_att_backend.create_vit_att_state()) self.model.load_model(weight_dir) self.model = self.model.cuda() self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) From 7ec0e7af63713a1527175119fd3cdc1827e458d0 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 02:49:01 +0000 Subject: [PATCH 02/13] add vit_attention dirs. 
--- lightllm/common/basemodel/attention_vit/__init__.py | 0 lightllm/common/basemodel/attention_vit/fa3/__init__.py | 0 lightllm/common/basemodel/attention_vit/triton/__init__.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 lightllm/common/basemodel/attention_vit/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/fa3/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/triton/__init__.py diff --git a/lightllm/common/basemodel/attention_vit/__init__.py b/lightllm/common/basemodel/attention_vit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/fa3/__init__.py b/lightllm/common/basemodel/attention_vit/fa3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/triton/__init__.py b/lightllm/common/basemodel/attention_vit/triton/__init__.py new file mode 100644 index 000000000..e69de29bb From e76b3cc9e5ad7e04b053ce35a2276aba58810feb Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:01:12 +0000 Subject: [PATCH 03/13] fix. --- .../basemodel/attention_vit/base_att.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 lightllm/common/basemodel/attention_vit/base_att.py diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py new file mode 100644 index 000000000..1b51ae08a --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -0,0 +1,46 @@ +import torch +from abc import ABC, abstractmethod + + +class BaseVitAttBackend: + """ + 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, + 这个是单列模式, 每种backend只有一个实例 + """ + + _instances = {} + + def __new__(cls, *args, **kwargs): + """ + 重写__new__方法实现单例模式 + """ + # 检查是否已经有该类的实例 + if cls not in cls._instances: + # 创建新实例并存储 + instance = super().__new__(cls) + cls._instances[cls] = instance + # 返回已有的实例 + return cls._instances[cls] + + def __init__(self, model): + self.model = model + + def create_vit_att_state(self) -> "BaseVitAttState": + raise NotImplementedError("not impl") + + +class BaseVitAttState(ABC): + + backend: BaseVitAttBackend = None + + @abstractmethod + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + raise NotImplementedError("not impl") From f8537bb54e95eebc9b90a450632d6d9adc53b051 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:08:26 +0000 Subject: [PATCH 04/13] fix --- .../common/basemodel/attention_vit/fa3/fp.py | 68 +++++++++++++++++++ .../basemodel/attention_vit/triton/fp.py | 36 ++++++++++ 2 files changed, 104 insertions(+) create mode 100644 lightllm/common/basemodel/attention_vit/fa3/fp.py create mode 100644 lightllm/common/basemodel/attention_vit/triton/fp.py diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py new file mode 100644 index 000000000..dfd437581 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -0,0 +1,68 @@ +import dataclasses +import torch +from ..base_att import BaseVitAttState, BaseVitAttBackend +from lightllm.utils.sgl_utils import flash_attn_with_kvcache + + +class Fa3VitAttBackend(BaseVitAttBackend): + def __init__(self, model): + super().__init__(model=model) + + def create_vit_att_state(self) -> "Fa3VitAttState": + return Fa3VitAttState(backend=self) + + +@dataclasses.dataclass +class 
Fa3VitAttState(BaseVitAttState): + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> None: + self.backend: Fa3VitAttBackend = self.backend # for typing + + head_dim = q.shape[-1] + softmax_scale = head_dim ** -0.5 + window_size = (-1, -1) + torch.ops.sgl_kernel.fwd.default( + q, + k, + v, + None, # k_new + None, # v_new + None, # qv + o, # out + cu_seqlens, + cu_seqlens, + None, # cu_seqlens_k_new + None, + None, + max_seqlen, + max_seqlen, + None, # page_table, + None, # kv_batch_idx + None, # leftpad_k + None, # rotary cos + None, # rotary sin + None, # seqlens_rotary + None, + None, + None, + softmax_scale, + False, + window_size[0], + window_size[1], + 0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, + ) + + return o diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py new file mode 100644 index 000000000..506313a8d --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -0,0 +1,36 @@ +import dataclasses +import torch +from ..base_att import BaseVitAttBackend, BaseVitAttState + + +class TritonVitAttBackend(BaseVitAttBackend): + def create_vit_att_state(self, infer_state) -> "TritonVitAttState": + return TritonVitAttState(backend=self, infer_state=infer_state) + + +@dataclasses.dataclass +class TritonVitAttState(BaseVitAttState): + def init_state(self): + pass + + def _vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + alloc_func=torch.empty, + ): + from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd + + _flash_attention_triton_fwd( + q, + k, + v, + o, + cu_seqlens, # q k v cu_seqlens, + max_seqlen, + ) + return From 65549a746da24a0df31c78cef3cf3b380f5b76b5 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:10:33 +0000 Subject: [PATCH 05/13] fix --- .../common/basemodel/attention/__init__.py | 1 - .../common/basemodel/attention/base_att.py | 20 ------ .../basemodel/attention/create_utils.py | 33 +-------- lightllm/common/basemodel/attention/fa3/fp.py | 69 +------------------ .../common/basemodel/attention/triton/fp.py | 33 +-------- 5 files changed, 3 insertions(+), 153 deletions(-) diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py index d6fba6966..80df54549 100644 --- a/lightllm/common/basemodel/attention/__init__.py +++ b/lightllm/common/basemodel/attention/__init__.py @@ -15,5 +15,4 @@ get_decode_att_backend_class, get_mla_prefill_att_backend_class, get_mla_decode_att_backend_class, - get_vit_att_backend_class, ) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 28fd5596e..859d97ca8 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -37,9 +37,6 @@ def create_att_prefill_state(self) -> "BasePrefillAttState": def create_att_decode_state(self) -> "BaseDecodeAttState": raise NotImplementedError("not impl") - def create_vit_att_state(self) -> "BaseVitAttState": - raise NotImplementedError("not impl") - def _find_layer_index( self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"] ) -> int: @@ -118,20 +115,3 @@ def decode_att( 
alloc_func=torch.empty, ) -> torch.Tensor: pass - - -class BaseVitAttState(ABC): - - backend: BaseAttBackend = None - - @abstractmethod - def vit_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int, - ) -> torch.Tensor: - raise NotImplementedError("not impl") diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 40844c0e0..19252cf13 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -10,7 +10,7 @@ from .triton.int4kv import Int4kvTritonAttBackend from .triton.int8kv import Int8kvTritonAttBackend from .triton.mla import MlaTritonAttBackend -from .fa3.fp import Fa3AttBackend, Fa3ViTAttBackend +from .fa3.fp import Fa3AttBackend from .fa3.fp8 import Fp8Fa3AttBackend from .fa3.mla import MlaFa3AttBackend from .flashinfer.fp8 import Fp8FlashInferAttBackend @@ -46,13 +46,6 @@ }, } -vit_data_type_to_backend = { - "None": { - "triton": TritonAttBackend, - "fa3": Fa3ViTAttBackend, - }, -} - def _auto_select_backend( llm_dtype: str, is_mla: bool = False, priority_list: list = ["fa3", "flashinfer", "triton"] @@ -67,7 +60,6 @@ def _auto_select_backend( for backend_name in priority_list: if validate(backend_name): logger.info(f"Auto-selected {backend_name} backend (validated)") - print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") return backend_map[llm_dtype][backend_name] # Fallback to triton without validation (should not happen) @@ -75,25 +67,6 @@ def _auto_select_backend( return backend_map[llm_dtype]["triton"] -def _auto_select_vit_backend(llm_dtype: str, priority_list: list = ["fa3", "triton"]) -> type: - """Auto-select the best available backend with validation for vit. - - Priority: FA3 > Triton - Each backend is validated in a subprocess with ground truth checks. 
- """ - backend_map = vit_data_type_to_backend - - for backend_name in priority_list: - if validate(backend_name): - logger.info(f"Auto-selected {backend_name} backend (validated) for ViT") - print(f"llm_dtype is {llm_dtype}, backend_name is {backend_name} ") - return backend_map[llm_dtype][backend_name] - - # Fallback to triton without validation (should not happen) - logger.warning("No backend validation succeeded for vit, falling back to triton") - return backend_map[llm_dtype]["triton"] - - def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend: args = get_env_start_args() llm_dtype = args.llm_kv_type @@ -132,7 +105,3 @@ def get_mla_decode_att_backend_class(index=0, priority_list: list = ["fa3", "fla return mla_data_type_to_backend[llm_dtype][backend_str] else: return _auto_select_backend(llm_dtype, is_mla=True, priority_list=priority_list) - - -def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "triton"]) -> BaseAttBackend: - return _auto_select_vit_backend(llm_dtype="None", priority_list=priority_list) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 7369f7dee..952bb39d9 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -1,6 +1,6 @@ import dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache @@ -37,14 +37,6 @@ def create_att_decode_state(self, infer_state) -> "Fa3DecodeAttState": return Fa3DecodeAttState(backend=self, infer_state=infer_state) -class Fa3ViTAttBackend(BaseAttBackend): - def __init__(self, model): - super().__init__(model=model) - - def create_vit_att_state(self) -> "Fa3VitAttState": - return Fa3VitAttState(backend=self) - - @dataclasses.dataclass class Fa3PrefillAttState(BasePrefillAttState): cu_seqlens_q: torch.Tensor = None @@ -249,62 +241,3 @@ def _normal_decode_att( sinks=sink_weight, ) return o - - -@dataclasses.dataclass -class Fa3VitAttState(BaseVitAttState): - - backend: "Fa3ViTAttBackend" - - def vit_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int, - ) -> None: - self.backend: Fa3ViTAttBackend = self.backend # for typing - - head_dim = q.shape[-1] - softmax_scale = head_dim ** -0.5 - window_size = (-1, -1) - torch.ops.sgl_kernel.fwd.default( - q, - k, - v, - None, # k_new - None, # v_new - None, # qv - o, # out - cu_seqlens, - cu_seqlens, - None, # cu_seqlens_k_new - None, - None, - max_seqlen, - max_seqlen, - None, # page_table, - None, # kv_batch_idx - None, # leftpad_k - None, # rotary cos - None, # rotary sin - None, # seqlens_rotary - None, - None, - None, - softmax_scale, - False, - window_size[0], - window_size[1], - 0.0, - is_rotary_interleaved=False, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - sinks=None, - ) - - return o diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index c35ad362a..d29f15ec3 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -1,6 +1,6 @@ import 
dataclasses import torch -from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl, BaseVitAttState +from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl from typing import Optional @@ -11,9 +11,6 @@ def create_att_prefill_state(self, infer_state) -> "TritonPrefillAttState": def create_att_decode_state(self, infer_state) -> "TritonDecodeAttState": return TritonDecodeAttState(backend=self, infer_state=infer_state) - def create_vit_att_state(self, infer_state) -> "TritonDecodeAttState": - return TritonVitAttState(backend=self, infer_state=infer_state) - @dataclasses.dataclass class TritonPrefillAttState(BasePrefillAttState): @@ -276,31 +273,3 @@ def _normal_decode_stage3_att( b_seq_len=self.infer_state.b_seq_len, ) return o_tensor - - -@dataclasses.dataclass -class TritonVitAttState(BaseVitAttState): - def init_state(self): - pass - - def _vit_att( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - o: torch.Tensor, - cu_seqlens: torch.Tensor, - max_seqlen: int, - alloc_func=torch.empty, - ): - from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd - - _flash_attention_triton_fwd( - q, - k, - v, - o, - cu_seqlens, # q k v cu_seqlens, - max_seqlen, - ) - return From e35427c6436d61ea1a8f22494846b366bdaf8b36 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 03:26:39 +0000 Subject: [PATCH 06/13] fix --- lightllm/common/basemodel/attention_vit/base_att.py | 10 +--------- lightllm/common/basemodel/attention_vit/fa3/fp.py | 9 +-------- lightllm/common/basemodel/attention_vit/triton/fp.py | 12 +----------- 3 files changed, 3 insertions(+), 28 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py index 1b51ae08a..e43475ac0 100644 --- a/lightllm/common/basemodel/attention_vit/base_att.py +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod -class BaseVitAttBackend: +class BaseVitAttBackend(ABC): """ 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, 这个是单列模式, 每种backend只有一个实例 @@ -25,14 +25,6 @@ def __new__(cls, *args, **kwargs): def __init__(self, model): self.model = model - def create_vit_att_state(self) -> "BaseVitAttState": - raise NotImplementedError("not impl") - - -class BaseVitAttState(ABC): - - backend: BaseVitAttBackend = None - @abstractmethod def vit_att( self, diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index dfd437581..fa7f14adb 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -1,19 +1,12 @@ import dataclasses import torch -from ..base_att import BaseVitAttState, BaseVitAttBackend -from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from ..base_att import BaseVitAttBackend class Fa3VitAttBackend(BaseVitAttBackend): def __init__(self, model): super().__init__(model=model) - def create_vit_att_state(self) -> "Fa3VitAttState": - return Fa3VitAttState(backend=self) - - -@dataclasses.dataclass -class Fa3VitAttState(BaseVitAttState): def vit_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py index 506313a8d..2c012f5a2 100644 --- a/lightllm/common/basemodel/attention_vit/triton/fp.py +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -1,18 +1,8 @@ -import 
dataclasses import torch -from ..base_att import BaseVitAttBackend, BaseVitAttState +from ..base_att import BaseVitAttBackend class TritonVitAttBackend(BaseVitAttBackend): - def create_vit_att_state(self, infer_state) -> "TritonVitAttState": - return TritonVitAttState(backend=self, infer_state=infer_state) - - -@dataclasses.dataclass -class TritonVitAttState(BaseVitAttState): - def init_state(self): - pass - def _vit_att( self, q: torch.Tensor, From c18dd256f752e5d359191b142a9fd030848047c6 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 04:38:01 +0000 Subject: [PATCH 07/13] fix --- lightllm/common/basemodel/attention_vit/fa3/fp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index fa7f14adb..ed1be1400 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -16,7 +16,6 @@ def vit_att( cu_seqlens: torch.Tensor, max_seqlen: int, ) -> None: - self.backend: Fa3VitAttBackend = self.backend # for typing head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 From ad46e4aa63ac5fb78f4818a7d78aa67f3bd42825 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 23 Jan 2026 06:02:03 +0000 Subject: [PATCH 08/13] fix --- lightllm/common/basemodel/basemodel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index e97381884..26d51af3d 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -32,7 +32,7 @@ from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel from lightllm.common.triton_utils.autotuner import Autotuner from lightllm.utils.infer_utils import post_empty_cache -from .attention import get_prefill_att_backend_class, get_decode_att_backend_class, get_vit_att_backend_class +from .attention import get_prefill_att_backend_class, get_decode_att_backend_class from .attention import BaseAttBackend logger = init_logger(__name__) @@ -119,6 +119,7 @@ def __init__(self, kvargs): self._init_custom() # wait必须在init cudagraph 之前,避免错误捕获 self._wait_other_modules_ready() + self._init_att_backend() self._init_att_backend1() From ee1d9aa3e6d52f322042c908bbd1945df95d3866 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:32:23 +0000 Subject: [PATCH 09/13] fix0126 --- .../basemodel/attention_vit/base_att.py | 4 +- .../basemodel/attention_vit/create_utils.py | 101 ++++++++++++++++++ .../common/basemodel/attention_vit/fa3/fp.py | 5 +- .../basemodel/attention_vit/sdpa/__init__.py | 0 .../common/basemodel/attention_vit/sdpa/fp.py | 48 +++++++++ .../basemodel/attention_vit/triton/fp.py | 10 +- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 4 +- lightllm/models/qwen2_vl/qwen2_visual.py | 5 +- lightllm/models/qwen3_vl/qwen3_visual.py | 5 - .../layer_infer/transformer_layer_infer.py | 4 +- lightllm/server/api_cli.py | 10 ++ lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/visualserver/__init__.py | 16 +++ .../visualserver/model_infer/model_rpc.py | 10 +- .../benchmark/static_inference/model_infer.py | 1 + 15 files changed, 192 insertions(+), 32 deletions(-) create mode 100644 lightllm/common/basemodel/attention_vit/create_utils.py create mode 100644 lightllm/common/basemodel/attention_vit/sdpa/__init__.py create mode 100644 lightllm/common/basemodel/attention_vit/sdpa/fp.py diff --git 
a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py index e43475ac0..405aeb245 100644 --- a/lightllm/common/basemodel/attention_vit/base_att.py +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -22,8 +22,8 @@ def __new__(cls, *args, **kwargs): # 返回已有的实例 return cls._instances[cls] - def __init__(self, model): - self.model = model + def __init__(self): + pass @abstractmethod def vit_att( diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py new file mode 100644 index 000000000..8983ea087 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/create_utils.py @@ -0,0 +1,101 @@ +import torch +from lightllm.utils.log_utils import init_logger +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.backend_validator import _validate_triton, _compute_ground_truth +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend +from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend +from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend + +logger = init_logger(__name__) + + +vit_att_backend = {"triton": TritonVitAttBackend, "sdpa": SdpaVitAttBackend, "fa3": Fa3VitAttBackend} + + +def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "sdpa", "triton"]) -> BaseVitAttBackend: + args = get_env_start_args() + backend_str = args.vit_att_backend[index] + if backend_str != "auto": + return vit_att_backend[backend_str] + else: + return _select_vit_backend(priority_list=priority_list) + + +def _select_vit_backend(priority_list: list = ["fa3", "sdpa", "triton"]) -> type: + """Auto-select the best available backend with validation for VIT. + + Priority: FA3 > Sdpa > Triton + Each backend is validated in a subprocess with ground truth checks. 
+ """ + backend_map = vit_att_backend + + for backend_name in priority_list: + if validate(backend_name): + logger.info(f"Auto-selected {backend_name} backend (validated) for VIT") + return backend_map[backend_name] + + # Fallback to triton without validation (should not happen) + logger.warning("No backend validation succeeded, falling back to triton") + return backend_map["triton"] + + +def validate(backend_name: str) -> bool: + if backend_name == "fa3": + validate_ok = _validate_fa3() + elif backend_name == "sdpa": + validate_ok = _validate_sdpa() + elif backend_name == "triton": + validate_ok = _validate_triton() + else: + raise ValueError("not suuported vit attn backend") + return validate_ok + + +def _validate_fa3(): + """Validate FA3 with ground truth.""" + from lightllm.utils.sgl_utils import flash_attn_varlen_func + + if flash_attn_varlen_func is None: + return False + + batch, heads, seq, dim = 1, 4, 8, 64 + q = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + k = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + v = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda") + + expected = _compute_ground_truth(q, k, v) + + q_flat = q.transpose(1, 2).reshape(batch * seq, heads, dim) + k_flat = k.transpose(1, 2).reshape(batch * seq, heads, dim) + v_flat = v.transpose(1, 2).reshape(batch * seq, heads, dim) + cu_seqlens = torch.arange(0, batch * seq + 1, seq, dtype=torch.int32, device="cuda") + + out = flash_attn_varlen_func( + q=q_flat, + k=k_flat, + v=v_flat, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq, + max_seqlen_k=seq, + softmax_scale=1.0 / (dim ** 0.5), + causal=True, + ) + out = out.reshape(batch, seq, heads, dim).transpose(1, 2) + torch.cuda.synchronize() + + if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2): + return False + + return True + + +def _validate_sdpa(): + """Validate SDPA Attn""" + from torch.nn.functional import scaled_dot_product_attention + + if scaled_dot_product_attention is None: + return False + + return True diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index ed1be1400..cec91165d 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -1,12 +1,9 @@ import dataclasses import torch -from ..base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend class Fa3VitAttBackend(BaseVitAttBackend): - def __init__(self, model): - super().__init__(model=model) - def vit_att( self, q: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/sdpa/__init__.py b/lightllm/common/basemodel/attention_vit/sdpa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py new file mode 100644 index 000000000..7f57d98a2 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -0,0 +1,48 @@ +import torch +import torch.nn.functional as F +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend + + +class SdpaVitAttBackend(BaseVitAttBackend): + def vit_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + assert q.ndim == k.ndim == v.ndim == o.ndim == 3 + assert cu_seqlens is not None and cu_seqlens.ndim == 1 + + cu = 
cu_seqlens.to(device=q.device) + B = cu.numel() - 1 + + with torch.no_grad(): + for b in range(B): + s = int(cu[b].item()) + e = int(cu[b + 1].item()) + L = e - s + if L <= 0: + continue + if max_seqlen: + assert L <= max_seqlen + + # [L, H, D] -> [1, H, L, D] + q_ = q[s:e].permute(1, 0, 2).unsqueeze(0) + k_ = k[s:e].permute(1, 0, 2).unsqueeze(0) + v_ = v[s:e].permute(1, 0, 2).unsqueeze(0) + + out = F.scaled_dot_product_attention( + q_, + k_, + v_, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ) + # [1, H, L, D] -> [L, H, D] + o[s:e].copy_(out.squeeze(0).permute(1, 0, 2)) + + return o diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py index 2c012f5a2..51e47f056 100644 --- a/lightllm/common/basemodel/attention_vit/triton/fp.py +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -1,9 +1,10 @@ import torch -from ..base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd class TritonVitAttBackend(BaseVitAttBackend): - def _vit_att( + def vit_att( self, q: torch.Tensor, k: torch.Tensor, @@ -11,10 +12,7 @@ def _vit_att( o: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, - alloc_func=torch.empty, ): - from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd - _flash_attention_triton_fwd( q, k, @@ -23,4 +21,4 @@ def _vit_att( cu_seqlens, # q k v cu_seqlens, max_seqlen, ) - return + return o diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 498f82e14..7156a5ce2 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -13,7 +13,7 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -74,7 +74,7 @@ def forward( k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin) attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 7d53c326f..8cf23bd57 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -32,8 +32,8 @@ from transformers.activations import ACT2FN from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from 
lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -127,7 +127,6 @@ def __init__(self, dim: int, num_heads: int = 16) -> None: self.num_heads = num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) - self.vit_att_backend = None def forward( self, @@ -144,7 +143,7 @@ def forward( attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - self.vit_att_backend.vit_att(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index dd119f2de..e3e350729 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -181,11 +181,6 @@ def __init__( ) self._init_datatype() - def _init_vit_att(self, vit_att): - for blk in self.blocks: - blk.attn.vit_att_backend = vit_att - return - def _init_datatype(self): if isinstance(self.data_type, torch.dtype): return diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 0d55d1b57..cada55e58 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -3,10 +3,10 @@ from typing import Tuple from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -105,7 +105,7 @@ def _context_attention_kernel(self, q, k, v) -> torch.Tensor: q, k, v, out = map(reshape, (q, k, v, out)) cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len max_seqlen = seq_len - flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, out, cu_seqlens, max_seqlen) return out.reshape(batch_size, seq_len, -1) def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 44cc38822..199f0dccc 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -333,6 +333,16 @@ def make_argument_parser() -> argparse.ArgumentParser: auto: automatically select best backend based on GPU and available packages (priority: fa3 > flashinfer > triton)""", ) + parser.add_argument( + "--vit_att_backend", + type=str, + nargs="+", + choices=["auto", "triton", "fa3", "sdpa"], + default=["auto"], + help="""vit attention kernel used in vlm. 
+ auto: automatically select best backend based on GPU and available packages + (priority: fa3 > sdpa > triton)""", + ) parser.add_argument( "--llm_kv_type", type=str, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 239cebfdd..0608d3972 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -121,6 +121,7 @@ class StartArgs: llm_decode_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) + vit_att_backend: List[str] = field(default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa"]}) llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) diff --git a/lightllm/server/visualserver/__init__.py b/lightllm/server/visualserver/__init__.py index e69de29bb..c8c108938 100644 --- a/lightllm/server/visualserver/__init__.py +++ b/lightllm/server/visualserver/__init__.py @@ -0,0 +1,16 @@ +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.create_utils import get_vit_att_backend_class + +VIT_ATTN_BACKEND: BaseVitAttBackend = None + + +def init_vit_att_backend(): + global VIT_ATTN_BACKEND + VIT_ATTN_BACKEND = get_vit_att_backend_class(index=0)() + return + + +def get_vit_attn_backend(): + if VIT_ATTN_BACKEND is None: + raise RuntimeError("VIT_ATTN_BACKEND is not initialized. Call init_vit_att_backend() first.") + return VIT_ATTN_BACKEND.vit_att diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 79434f5b9..674c7a83d 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -24,15 +24,11 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient -from lightllm.common.basemodel.basemodel import BaseAttBackend, get_vit_att_backend_class +from lightllm.server.visualserver import init_vit_att_backend from lightllm.utils.dist_utils import set_global_rank class VisualModelRpcServer(rpyc.Service): - def _init_vit_att_backend(self): - self.vit_att_backend: BaseAttBackend = get_vit_att_backend_class(index=0)(model=self) - return - def exposed_init_model(self, kvargs): kvargs = obtain(kvargs) import torch @@ -48,7 +44,6 @@ def exposed_init_model(self, kvargs): self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] - set_global_rank(kvargs["tp_rank_id"]) # 这里看后面怎么改 init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) @@ -87,8 +82,7 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") - self._init_vit_att_backend() - self.model._init_vit_att(self.vit_att_backend.create_vit_att_state()) + init_vit_att_backend() self.model.load_model(weight_dir) self.model = self.model.cuda() self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) diff --git 
a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 7f1c2b493..919f379b9 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -43,6 +43,7 @@ def test_model_inference(args): "disable_cudagraph": args.disable_cudagraph, "llm_prefill_att_backend": args.llm_prefill_att_backend, "llm_decode_att_backend": args.llm_decode_att_backend, + "vit_att_backend": args.vit_att_backend, "llm_kv_type": args.llm_kv_type, "llm_kv_quant_group_size": args.llm_kv_quant_group_size, } From 31cf3b0299c24a06012e29502733c649ee60f19a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:34:51 +0000 Subject: [PATCH 10/13] fix0126 --- lightllm/models/qwen2_vl/qwen2_visual.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 8cf23bd57..0e2af0cbb 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -234,11 +234,6 @@ def _init_datatype(self): raise ValueError(f"Unsupport datatype {self.data_type}!") return - def _init_vit_att(self, vit_att): - for blk in self.blocks: - blk.attn.vit_att_backend = vit_att - return - def load_model(self, weight_dir): processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") From f0bfc1cd5e9b17fa10d10b14270b86df9c34c1f4 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:35:43 +0000 Subject: [PATCH 11/13] fix0126 --- lightllm/models/qwen3_vl/qwen3_visual.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index e3e350729..00ad6c05a 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -129,6 +129,7 @@ def __init__( ): super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") + self.depth = depth self.out_hidden_size = out_hidden_size self.hidden_size = hidden_size From 7f0fd35b292a0c6c4d0dcc26b07fd018a5c364a9 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:37:54 +0000 Subject: [PATCH 12/13] fix0126 --- lightllm/common/basemodel/attention_vit/base_att.py | 4 ++-- lightllm/common/basemodel/attention_vit/fa3/fp.py | 2 +- lightllm/common/basemodel/attention_vit/sdpa/fp.py | 2 +- lightllm/common/basemodel/attention_vit/triton/fp.py | 2 +- lightllm/server/visualserver/__init__.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py index 405aeb245..49bf6ad74 100644 --- a/lightllm/common/basemodel/attention_vit/base_att.py +++ b/lightllm/common/basemodel/attention_vit/base_att.py @@ -4,7 +4,7 @@ class BaseVitAttBackend(ABC): """ - 用于创建支持各种不同的AttBackend, 如 fa3, flashinfer, triton 实现等, + 用于创建支持各种不同的AttBackend, 如 fa3, sdpa, triton 实现等, 这个是单列模式, 每种backend只有一个实例 """ @@ -26,7 +26,7 @@ def __init__(self): pass @abstractmethod - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index cec91165d..e77a0cec7 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -4,7 +4,7 @@ class Fa3VitAttBackend(BaseVitAttBackend): - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git 
a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py index 7f57d98a2..9c7d5e311 100644 --- a/lightllm/common/basemodel/attention_vit/sdpa/fp.py +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -4,7 +4,7 @@ class SdpaVitAttBackend(BaseVitAttBackend): - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py index 51e47f056..c38a46633 100644 --- a/lightllm/common/basemodel/attention_vit/triton/fp.py +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -4,7 +4,7 @@ class TritonVitAttBackend(BaseVitAttBackend): - def vit_att( + def _vit_att_fwd( self, q: torch.Tensor, k: torch.Tensor, diff --git a/lightllm/server/visualserver/__init__.py b/lightllm/server/visualserver/__init__.py index c8c108938..026458447 100644 --- a/lightllm/server/visualserver/__init__.py +++ b/lightllm/server/visualserver/__init__.py @@ -13,4 +13,4 @@ def init_vit_att_backend(): def get_vit_attn_backend(): if VIT_ATTN_BACKEND is None: raise RuntimeError("VIT_ATTN_BACKEND is not initialized. Call init_vit_att_backend() first.") - return VIT_ATTN_BACKEND.vit_att + return VIT_ATTN_BACKEND._vit_att_fwd From a5aa81782119b0da90390c5dd1705352fa5ac92a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 26 Jan 2026 10:46:52 +0000 Subject: [PATCH 13/13] fix0126 --- lightllm/server/visualserver/model_infer/model_rpc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 674c7a83d..a81e89efe 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -45,6 +45,7 @@ def exposed_init_model(self, kvargs): self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] init_vision_distributed_env(kvargs) + init_vit_att_backend() model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: @@ -82,7 +83,6 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") - init_vit_att_backend() self.model.load_model(weight_dir) self.model = self.model.cuda() self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True)
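
Taken together, the series replaces the direct flash_attention_fwd calls in the ViT code paths with a process-global attention backend selected via --vit_att_backend ("auto" validates fa3, then sdpa, then triton, in that order). The sketch below is not part of the patches; it shows, under stated assumptions, how the new API is exercised inside the visual server process after patch 13. The tensor shapes and sizes are illustrative only, following the packed varlen layout ([total_tokens, num_heads, head_dim] plus int32 cu_seqlens offsets) used by the existing kernels, and it presumes the server start args (get_env_start_args) are already populated so the backend choice can be read.

import torch
from lightllm.server.visualserver import init_vit_att_backend, get_vit_attn_backend

# One-time setup in the visual server process (patch 13 moves this call next to
# init_vision_distributed_env); it reads args.vit_att_backend and instantiates the
# singleton backend, with "auto" validating fa3 -> sdpa -> triton in that order.
init_vit_att_backend()

# Packed varlen layout: two images of 16 and 32 patch tokens, 4 heads, head_dim 64.
# These sizes are illustrative assumptions, not values taken from the patches.
cu_seqlens = torch.tensor([0, 16, 48], dtype=torch.int32, device="cuda")
max_seqlen = 32
q = torch.randn(48, 4, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

# get_vit_attn_backend() returns the selected backend's bound _vit_att_fwd; the
# result is written into `o`, matching how the qwen2/qwen2.5 visual attention
# modules and the ViT transformer layer call it after this series.
get_vit_attn_backend()(q, k, v, o, cu_seqlens, max_seqlen)

Because BaseVitAttBackend overrides __new__ to keep one instance per backend class, repeated init_vit_att_backend() calls reuse the same object, so the module-level VIT_ATTN_BACKEND stays stable across all attention blocks and visual models in the process.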