diff --git a/lightllm/common/basemodel/attention_vit/__init__.py b/lightllm/common/basemodel/attention_vit/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/common/basemodel/attention_vit/base_att.py b/lightllm/common/basemodel/attention_vit/base_att.py
new file mode 100644
index 000000000..49bf6ad74
--- /dev/null
+++ b/lightllm/common/basemodel/attention_vit/base_att.py
@@ -0,0 +1,38 @@
+import torch
+from abc import ABC, abstractmethod
+
+
+class BaseVitAttBackend(ABC):
+    """
+    Base class for the ViT attention backends (fa3, sdpa, triton, xformers, ...).
+    Implemented as a singleton: each backend class has exactly one instance.
+    """
+
+    _instances = {}
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Override __new__ to implement the singleton pattern.
+        """
+        # Check whether an instance of this class already exists
+        if cls not in cls._instances:
+            # Create a new instance and cache it
+            instance = super().__new__(cls)
+            cls._instances[cls] = instance
+        # Return the cached instance
+        return cls._instances[cls]
+
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def _vit_att_fwd(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        o: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int,
+    ) -> torch.Tensor:
+        raise NotImplementedError("not impl")
diff --git a/lightllm/common/basemodel/attention_vit/create_utils.py b/lightllm/common/basemodel/attention_vit/create_utils.py
new file mode 100644
index 000000000..3aa23f16d
--- /dev/null
+++ b/lightllm/common/basemodel/attention_vit/create_utils.py
@@ -0,0 +1,122 @@
+import torch
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.backend_validator import _validate_triton, _compute_ground_truth
+from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
+from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend
+from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend
+from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend
+from lightllm.common.basemodel.attention_vit.xformers.fp import XformersVitAttBackend
+
+logger = init_logger(__name__)
+
+
+vit_att_backend = {
+    "triton": TritonVitAttBackend,
+    "sdpa": SdpaVitAttBackend,
+    "fa3": Fa3VitAttBackend,
+    "xformers": XformersVitAttBackend,
+}
+
+
+def get_vit_att_backend_class(
+    index=0, priority_list: list = ["fa3", "xformers", "sdpa", "triton"]
+) -> type:
+    args = get_env_start_args()
+    backend_str = args.vit_att_backend[index]
+    if backend_str != "auto":
+        logger.info(f"Selected {backend_str} backend for VIT")
+        return vit_att_backend[backend_str]
+    else:
+        return _select_vit_backend(priority_list=priority_list)
+
+
+def _select_vit_backend(priority_list: list = ["fa3", "xformers", "sdpa", "triton"]) -> type:
+    """Auto-select the best available backend with validation for VIT.
+
+    Priority: FA3 > Xformers > SDPA > Triton.
+    Each candidate backend is validated with a ground-truth check before it is selected.
+    """
+    backend_map = vit_att_backend
+
+    for backend_name in priority_list:
+        if validate(backend_name):
+            logger.info(f"Auto-selected {backend_name} backend (validated) for VIT")
+            return backend_map[backend_name]
+
+    # Fallback to triton without validation (should not happen)
+    logger.warning("No backend validation succeeded, falling back to triton")
+    return backend_map["triton"]
+
+
+def validate(backend_name: str) -> bool:
+    if backend_name == "fa3":
+        validate_ok = _validate_fa3()
+    elif backend_name == "xformers":
+        validate_ok = _validate_xformers()
+    elif backend_name == "sdpa":
+        validate_ok = _validate_sdpa()
+    elif backend_name == "triton":
+        validate_ok = _validate_triton()
+    else:
+        raise ValueError(f"not supported vit attn backend: {backend_name}")
+    return validate_ok
+
+
+def _validate_fa3():
+    """Validate FA3 with ground truth."""
+    from lightllm.utils.sgl_utils import flash_attn_varlen_func
+
+    if flash_attn_varlen_func is None:
+        return False
+
+    batch, heads, seq, dim = 1, 4, 8, 64
+    q = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda")
+    k = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda")
+    v = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda")
+
+    expected = _compute_ground_truth(q, k, v)
+
+    q_flat = q.transpose(1, 2).reshape(batch * seq, heads, dim)
+    k_flat = k.transpose(1, 2).reshape(batch * seq, heads, dim)
+    v_flat = v.transpose(1, 2).reshape(batch * seq, heads, dim)
+    cu_seqlens = torch.arange(0, batch * seq + 1, seq, dtype=torch.int32, device="cuda")
+
+    out = flash_attn_varlen_func(
+        q=q_flat,
+        k=k_flat,
+        v=v_flat,
+        cu_seqlens_q=cu_seqlens,
+        cu_seqlens_k=cu_seqlens,
+        max_seqlen_q=seq,
+        max_seqlen_k=seq,
+        softmax_scale=1.0 / (dim ** 0.5),
+        causal=True,
+    )
+    out = out.reshape(batch, seq, heads, dim).transpose(1, 2)
+    torch.cuda.synchronize()
+
+    if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2):
+        return False
+
+    return True
+
+
+def _validate_xformers():
+    """Validate that xformers attention ops are available."""
+    from xformers import ops as xformers_ops
+
+    if xformers_ops is None:
+        return False
+
+    return True
+
+
+def _validate_sdpa():
+    """Validate that torch SDPA is available."""
+    from torch.nn.functional import scaled_dot_product_attention
+
+    if scaled_dot_product_attention is None:
+        return False
+
+    return True
diff --git a/lightllm/common/basemodel/attention_vit/fa3/__init__.py b/lightllm/common/basemodel/attention_vit/fa3/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py
new file mode 100644
index 000000000..e77a0cec7
--- /dev/null
+++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py
@@ -0,0 +1,57 @@
+import dataclasses
+import torch
+from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
+
+
+class Fa3VitAttBackend(BaseVitAttBackend):
+    def _vit_att_fwd(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        o: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int,
+    ) -> torch.Tensor:
+
+        head_dim = q.shape[-1]
+        softmax_scale = head_dim ** -0.5
+        window_size = (-1, -1)
+        torch.ops.sgl_kernel.fwd.default(
+            q,
+            k,
+            v,
+            None,  # k_new
+            None,  # v_new
+            None,  # qv
+            o,  # out
+            cu_seqlens,
+            cu_seqlens,
+            None,  # cu_seqlens_k_new
+            None,
+            None,
+            max_seqlen,
+            max_seqlen,
+            None,  # page_table,
+            None,  # kv_batch_idx
+            None,  # leftpad_k
+            None,  # rotary cos
+            None,  # rotary sin
+            None,  # seqlens_rotary
+            None,
+            None,
+            None,
+            softmax_scale,
+            False,  # is_causal
window_size[0], + window_size[1], + 0.0, + is_rotary_interleaved=False, + scheduler_metadata=None, + num_splits=1, + pack_gqa=None, + sm_margin=0, + sinks=None, + ) + + return o diff --git a/lightllm/common/basemodel/attention_vit/sdpa/__init__.py b/lightllm/common/basemodel/attention_vit/sdpa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/sdpa/fp.py b/lightllm/common/basemodel/attention_vit/sdpa/fp.py new file mode 100644 index 000000000..aea8b19bb --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/sdpa/fp.py @@ -0,0 +1,47 @@ +import torch +import torch.nn.functional as F +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend + + +class SdpaVitAttBackend(BaseVitAttBackend): + def _vit_att_fwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + assert q.ndim == k.ndim == v.ndim == o.ndim == 3 + assert cu_seqlens is not None and cu_seqlens.ndim == 1 + + B = cu_seqlens.numel() - 1 + + with torch.no_grad(): + for b in range(B): + s = int(cu_seqlens[b].item()) + e = int(cu_seqlens[b + 1].item()) + L = e - s + if L <= 0: + continue + if max_seqlen: + assert L <= max_seqlen + + # [L, H, D] -> [1, H, L, D] + q_ = q[s:e].permute(1, 0, 2).unsqueeze(0) + k_ = k[s:e].permute(1, 0, 2).unsqueeze(0) + v_ = v[s:e].permute(1, 0, 2).unsqueeze(0) + + out = F.scaled_dot_product_attention( + q_, + k_, + v_, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ) + # [1, H, L, D] -> [L, H, D] + o[s:e].copy_(out.squeeze(0).permute(1, 0, 2)) + + return o diff --git a/lightllm/common/basemodel/attention_vit/triton/__init__.py b/lightllm/common/basemodel/attention_vit/triton/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/triton/fp.py b/lightllm/common/basemodel/attention_vit/triton/fp.py new file mode 100644 index 000000000..c38a46633 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/triton/fp.py @@ -0,0 +1,24 @@ +import torch +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd + + +class TritonVitAttBackend(BaseVitAttBackend): + def _vit_att_fwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ): + _flash_attention_triton_fwd( + q, + k, + v, + o, + cu_seqlens, # q k v cu_seqlens, + max_seqlen, + ) + return o diff --git a/lightllm/common/basemodel/attention_vit/xformers/__init__.py b/lightllm/common/basemodel/attention_vit/xformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py new file mode 100644 index 000000000..be1d33948 --- /dev/null +++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py @@ -0,0 +1,39 @@ +import torch +import torch.nn.functional as F +from xformers import ops as xformers_ops +from xformers.ops import fmha +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend + + +class XformersVitAttBackend(BaseVitAttBackend): + @torch.no_grad() + def _vit_att_fwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + ) -> torch.Tensor: + assert q.ndim == k.ndim == v.ndim == o.ndim == 
3 + assert cu_seqlens is not None and cu_seqlens.ndim == 1 + assert q.shape == k.shape == v.shape == o.shape + + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).to(torch.int64).tolist() + seqlens = [int(L) for L in seqlens if int(L) > 0] + + if len(seqlens) == 0: + return o + if max_seqlen: + assert max(seqlens) <= max_seqlen + + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device) + + q_ = q.unsqueeze(0) # [1, T, H, D] + k_ = k.unsqueeze(0) # [1, T, H, D] + v_ = v.unsqueeze(0) # [1, T, H, D] + + out = xformers_ops.memory_efficient_attention(q_, k_, v_, attn_bias=attn_bias, p=0.0) + o.copy_(out.squeeze(0)) # [T, H, D] + return o diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 498f82e14..7156a5ce2 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -13,7 +13,7 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -74,7 +74,7 @@ def forward( k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin) attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 334ffc844..0e2af0cbb 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -32,8 +32,8 @@ from transformers.activations import ACT2FN from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -143,7 +143,7 @@ def forward( attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device) - flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output diff --git a/lightllm/models/vit/layer_infer/transformer_layer_infer.py b/lightllm/models/vit/layer_infer/transformer_layer_infer.py index 0d55d1b57..cada55e58 100644 --- a/lightllm/models/vit/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/vit/layer_infer/transformer_layer_infer.py @@ -3,10 +3,10 @@ from typing import Tuple from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight -from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from 
lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm +from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager @@ -105,7 +105,7 @@ def _context_attention_kernel(self, q, k, v) -> torch.Tensor: q, k, v, out = map(reshape, (q, k, v, out)) cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len max_seqlen = seq_len - flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen) + get_vit_attn_backend()(q, k, v, out, cu_seqlens, max_seqlen) return out.reshape(batch_size, seq_len, -1) def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 44cc38822..b6fffce4e 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -333,6 +333,16 @@ def make_argument_parser() -> argparse.ArgumentParser: auto: automatically select best backend based on GPU and available packages (priority: fa3 > flashinfer > triton)""", ) + parser.add_argument( + "--vit_att_backend", + type=str, + nargs="+", + choices=["auto", "triton", "fa3", "sdpa", "xformers"], + default=["auto"], + help="""vit attention kernel used in vlm. + auto: automatically select best backend based on GPU and available packages + (priority: fa3 > xformers > sdpa > triton)""", + ) parser.add_argument( "--llm_kv_type", type=str, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 239cebfdd..d914a1736 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -121,6 +121,9 @@ class StartArgs: llm_decode_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]} ) + vit_att_backend: List[str] = field( + default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} + ) llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]}) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]}) diff --git a/lightllm/server/visualserver/__init__.py b/lightllm/server/visualserver/__init__.py index e69de29bb..026458447 100644 --- a/lightllm/server/visualserver/__init__.py +++ b/lightllm/server/visualserver/__init__.py @@ -0,0 +1,16 @@ +from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.common.basemodel.attention_vit.create_utils import get_vit_att_backend_class + +VIT_ATTN_BACKEND: BaseVitAttBackend = None + + +def init_vit_att_backend(): + global VIT_ATTN_BACKEND + VIT_ATTN_BACKEND = get_vit_att_backend_class(index=0)() + return + + +def get_vit_attn_backend(): + if VIT_ATTN_BACKEND is None: + raise RuntimeError("VIT_ATTN_BACKEND is not initialized. 
Call init_vit_att_backend() first.") + return VIT_ATTN_BACKEND._vit_att_fwd diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d3d1610f3..a81e89efe 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -24,6 +24,8 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.server.visualserver import init_vit_att_backend +from lightllm.utils.dist_utils import set_global_rank class VisualModelRpcServer(rpyc.Service): @@ -42,8 +44,8 @@ def exposed_init_model(self, kvargs): self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] - init_vision_distributed_env(kvargs) + init_vit_att_backend() model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) try: diff --git a/requirements.txt b/requirements.txt index 8d9a011be..a3b9473f8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -87,6 +87,7 @@ librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 +xformers==0.0.32.post1 xxhash==3.6.0 torchvision==0.23.0 interegular==0.3.3 diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index 7f1c2b493..919f379b9 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -43,6 +43,7 @@ def test_model_inference(args): "disable_cudagraph": args.disable_cudagraph, "llm_prefill_att_backend": args.llm_prefill_att_backend, "llm_decode_att_backend": args.llm_decode_att_backend, + "vit_att_backend": args.vit_att_backend, "llm_kv_type": args.llm_kv_type, "llm_kv_quant_group_size": args.llm_kv_quant_group_size, }
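
The two snippets below are illustrative sketches, not part of the patch. First, a minimal sketch of the singleton behaviour that base_att.py gives every backend, and of how an additional backend could be hooked into the vit_att_backend registry from create_utils.py; DummyVitAttBackend and the "dummy" key are hypothetical names used only to show the pattern.

import torch

from lightllm.common.basemodel.attention_vit import create_utils
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend


class DummyVitAttBackend(BaseVitAttBackend):
    # Hypothetical backend: plain per-sequence softmax attention, for illustration only.
    def _vit_att_fwd(self, q, k, v, o, cu_seqlens, max_seqlen):
        scale = q.shape[-1] ** -0.5
        for b in range(cu_seqlens.numel() - 1):
            s, e = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
            # [L, H, D] x [M, H, D] -> [H, L, M]
            attn = torch.softmax(torch.einsum("lhd,mhd->hlm", q[s:e].float(), k[s:e].float()) * scale, dim=-1)
            # [H, L, M] x [M, H, D] -> [L, H, D]
            o[s:e].copy_(torch.einsum("hlm,mhd->lhd", attn, v[s:e].float()).to(o.dtype))
        return o


# BaseVitAttBackend.__new__ caches one instance per subclass, so repeated construction
# (as done by init_vit_att_backend) always yields the same object.
assert DummyVitAttBackend() is DummyVitAttBackend()

# Registration mirrors the existing dict in create_utils.py (the "dummy" key is hypothetical).
create_utils.vit_att_backend["dummy"] = DummyVitAttBackend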
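
Second, a sketch of the runtime call path this change wires up in the visual server. It assumes the start args (including --vit_att_backend) are already initialized, as they are when model_rpc.py calls init_vit_att_backend(), and that q/k/v/o are laid out as [total_tokens, heads, head_dim] with cu_seqlens marking sequence boundaries, matching the existing triton flash_attention_fwd convention; the head count and sequence lengths below are made up for illustration.

import torch

from lightllm.server.visualserver import init_vit_att_backend, get_vit_attn_backend

# Instantiate the singleton chosen by --vit_att_backend (or auto-selection); the visual
# worker does this once in exposed_init_model right after the distributed env is set up.
init_vit_att_backend()

heads, head_dim = 16, 80  # illustrative ViT head layout
cu_seqlens = torch.tensor([0, 256, 512], dtype=torch.int32, device="cuda")  # two 256-token images
total_tokens = int(cu_seqlens[-1])

q = torch.randn(total_tokens, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

# get_vit_attn_backend() returns the bound _vit_att_fwd of the selected backend, so the
# call site keeps the old flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen) signature.
get_vit_attn_backend()(q, k, v, o, cu_seqlens, max_seqlen=256)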