38 changes: 38 additions & 0 deletions lightllm/common/basemodel/attention_vit/base_att.py
@@ -0,0 +1,38 @@
import torch
from abc import ABC, abstractmethod


class BaseVitAttBackend(ABC):
"""
用于创建支持各种不同的AttBackend, 如 fa3, sdpa, triton 实现等,
这个是单列模式, 每种backend只有一个实例
"""

_instances = {}

def __new__(cls, *args, **kwargs):
"""
重写__new__方法实现单例模式
"""
        # Check whether an instance of this class already exists
        if cls not in cls._instances:
            # Create a new instance and cache it
            instance = super().__new__(cls)
            cls._instances[cls] = instance
        # Return the cached instance
        return cls._instances[cls]

def __init__(self):
pass

@abstractmethod
def _vit_att_fwd(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
        raise NotImplementedError("_vit_att_fwd is not implemented")
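A quick illustration of the singleton behaviour provided by the __new__ override above: a minimal sketch, where DummyVitAttBackend is a hypothetical subclass written only for this example.

import torch
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend


class DummyVitAttBackend(BaseVitAttBackend):
    # Hypothetical backend used only to illustrate the singleton behaviour.
    def _vit_att_fwd(self, q, k, v, o, cu_seqlens, max_seqlen) -> torch.Tensor:
        o.copy_(q)  # placeholder, not real attention
        return o


a = DummyVitAttBackend()
b = DummyVitAttBackend()
assert a is b  # __new__ caches exactly one instance per backend class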
101 changes: 101 additions & 0 deletions lightllm/common/basemodel/attention_vit/create_utils.py
@@ -0,0 +1,101 @@
import torch
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.backend_validator import _validate_triton, _compute_ground_truth
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend
from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend
from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend

logger = init_logger(__name__)


vit_att_backend = {"triton": TritonVitAttBackend, "sdpa": SdpaVitAttBackend, "fa3": Fa3VitAttBackend}


def get_vit_att_backend_class(index=0, priority_list: list = ["fa3", "sdpa", "triton"]) -> type:
args = get_env_start_args()
backend_str = args.vit_att_backend[index]
if backend_str != "auto":
return vit_att_backend[backend_str]
else:
return _select_vit_backend(priority_list=priority_list)


def _select_vit_backend(priority_list: list = ["fa3", "sdpa", "triton"]) -> type:
"""Auto-select the best available backend with validation for VIT.

    Priority: fa3 > sdpa > triton.
    Each candidate backend is validated (availability and, where applicable, a ground-truth check) before it is selected.
"""
backend_map = vit_att_backend

for backend_name in priority_list:
if validate(backend_name):
logger.info(f"Auto-selected {backend_name} backend (validated) for VIT")
return backend_map[backend_name]

# Fallback to triton without validation (should not happen)
logger.warning("No backend validation succeeded, falling back to triton")
return backend_map["triton"]


def validate(backend_name: str) -> bool:
if backend_name == "fa3":
validate_ok = _validate_fa3()
elif backend_name == "sdpa":
validate_ok = _validate_sdpa()
elif backend_name == "triton":
validate_ok = _validate_triton()
else:
raise ValueError("not suuported vit attn backend")
return validate_ok


def _validate_fa3():
"""Validate FA3 with ground truth."""
from lightllm.utils.sgl_utils import flash_attn_varlen_func

if flash_attn_varlen_func is None:
return False

batch, heads, seq, dim = 1, 4, 8, 64
q = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, heads, seq, dim, dtype=torch.bfloat16, device="cuda")

expected = _compute_ground_truth(q, k, v)

q_flat = q.transpose(1, 2).reshape(batch * seq, heads, dim)
k_flat = k.transpose(1, 2).reshape(batch * seq, heads, dim)
v_flat = v.transpose(1, 2).reshape(batch * seq, heads, dim)
cu_seqlens = torch.arange(0, batch * seq + 1, seq, dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
q=q_flat,
k=k_flat,
v=v_flat,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq,
max_seqlen_k=seq,
softmax_scale=1.0 / (dim ** 0.5),
causal=True,
)
out = out.reshape(batch, seq, heads, dim).transpose(1, 2)
torch.cuda.synchronize()

if not torch.allclose(out, expected, rtol=1e-2, atol=1e-2):
return False

return True


def _validate_sdpa():
"""Validate SDPA Attn"""
from torch.nn.functional import scaled_dot_product_attention

if scaled_dot_product_attention is None:
return False

return True
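The selection logic above reduces to: an explicit --vit_att_backend value is looked up directly in vit_att_backend, while "auto" walks the priority list and returns the first backend whose validation passes. A condensed sketch of that flow, as a stand-alone helper written only for illustration and using the names defined in this file:

def resolve_backend(backend_str: str) -> type:
    # Illustrative only: mirrors get_vit_att_backend_class / _select_vit_backend.
    if backend_str != "auto":
        return vit_att_backend[backend_str]  # explicit "triton" / "sdpa" / "fa3"
    for name in ["fa3", "sdpa", "triton"]:
        if validate(name):
            return vit_att_backend[name]  # first validated backend wins
    return vit_att_backend["triton"]  # last-resort fallback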
57 changes: 57 additions & 0 deletions lightllm/common/basemodel/attention_vit/fa3/fp.py
@@ -0,0 +1,57 @@
import torch
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend


class Fa3VitAttBackend(BaseVitAttBackend):
def _vit_att_fwd(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
    ) -> torch.Tensor:

head_dim = q.shape[-1]
softmax_scale = head_dim ** -0.5
window_size = (-1, -1)
torch.ops.sgl_kernel.fwd.default(
q,
k,
v,
None, # k_new
None, # v_new
None, # qv
o, # out
cu_seqlens,
cu_seqlens,
None, # cu_seqlens_k_new
None,
None,
max_seqlen,
max_seqlen,
None, # page_table,
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos
None, # rotary sin
None, # seqlens_rotary
None,
None,
None,
softmax_scale,
False,
window_size[0],
window_size[1],
0.0,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
sm_margin=0,
sinks=None,
)

return o
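All three backends consume the same packed varlen layout used in _validate_fa3 above: q, k, v and o are [total_tokens, heads, head_dim] tensors and cu_seqlens holds the prefix sums of the per-image sequence lengths. A minimal sketch of how that layout is built for two images with different patch counts (pure PyTorch, values chosen only for illustration):

import torch

heads, dim = 4, 64
seq_lens = [16, 9]  # e.g. two images with different numbers of patches
total = sum(seq_lens)

q = torch.randn(total, heads, dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.tensor(seq_lens, device="cuda").cumsum(0)
# cu_seqlens -> tensor([0, 16, 25], dtype=torch.int32)
max_seqlen = max(seq_lens)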
48 changes: 48 additions & 0 deletions lightllm/common/basemodel/attention_vit/sdpa/fp.py
@@ -0,0 +1,48 @@
import torch
import torch.nn.functional as F
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend


class SdpaVitAttBackend(BaseVitAttBackend):
def _vit_att_fwd(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
assert q.ndim == k.ndim == v.ndim == o.ndim == 3
assert cu_seqlens is not None and cu_seqlens.ndim == 1

cu = cu_seqlens.to(device=q.device)
B = cu.numel() - 1

with torch.no_grad():
for b in range(B):
s = int(cu[b].item())
e = int(cu[b + 1].item())
L = e - s
if L <= 0:
continue
if max_seqlen:
assert L <= max_seqlen

# [L, H, D] -> [1, H, L, D]
q_ = q[s:e].permute(1, 0, 2).unsqueeze(0)
k_ = k[s:e].permute(1, 0, 2).unsqueeze(0)
v_ = v[s:e].permute(1, 0, 2).unsqueeze(0)

out = F.scaled_dot_product_attention(
q_,
k_,
v_,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
)
# [1, H, L, D] -> [L, H, D]
o[s:e].copy_(out.squeeze(0).permute(1, 0, 2))

return o
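F.scaled_dot_product_attention has no cu_seqlens interface, so this backend loops over the sequences and reshapes each [L, H, D] slice to [1, H, L, D]. A short usage sketch on the packed layout from the earlier example, stand-alone and for illustration only:

import torch
from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend

heads, dim = 4, 64
seq_lens = [16, 9]
total = sum(seq_lens)
q = torch.randn(total, heads, dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.tensor(seq_lens, device="cuda").cumsum(0)

backend = SdpaVitAttBackend()  # singleton: repeated construction returns the same object
out = backend._vit_att_fwd(q, k, v, o, cu_seqlens, max_seqlen=max(seq_lens))
assert out.shape == (total, heads, dim)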
24 changes: 24 additions & 0 deletions lightllm/common/basemodel/attention_vit/triton/fp.py
@@ -0,0 +1,24 @@
import torch
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
from lightllm.models.vit.triton_kernel.flashattention_nopad import _flash_attention_triton_fwd


class TritonVitAttBackend(BaseVitAttBackend):
def _vit_att_fwd(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o: torch.Tensor,
cu_seqlens: torch.Tensor,
max_seqlen: int,
):
_flash_attention_triton_fwd(
q,
k,
v,
o,
            cu_seqlens,  # shared cu_seqlens for q, k and v
max_seqlen,
)
return o
4 changes: 2 additions & 2 deletions lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -13,7 +13,7 @@
from lightllm.server.multimodal_params import ImageItem
from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.server.visualserver import get_vit_attn_backend
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton

@@ -74,7 +74,7 @@ def forward(
k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin)

attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
4 changes: 2 additions & 2 deletions lightllm/models/qwen2_vl/qwen2_visual.py
@@ -32,8 +32,8 @@
from transformers.activations import ACT2FN
from safetensors import safe_open
from lightllm.server.multimodal_params import ImageItem
from lightllm.server.visualserver import get_vit_attn_backend
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton

@@ -143,7 +143,7 @@ def forward(

attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)

flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
get_vit_attn_backend()(q, k, v, attn_output, cu_seqlens, max_seqlen)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
4 changes: 2 additions & 2 deletions lightllm/models/vit/layer_infer/transformer_layer_infer.py
@@ -3,10 +3,10 @@

from typing import Tuple
from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
from lightllm.server.visualserver import get_vit_attn_backend
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


@@ -105,7 +105,7 @@ def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
q, k, v, out = map(reshape, (q, k, v, out))
cu_seqlens = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) * seq_len
max_seqlen = seq_len
flash_attention_fwd(q, k, v, out, cu_seqlens, max_seqlen)
get_vit_attn_backend()(q, k, v, out, cu_seqlens, max_seqlen)
return out.reshape(batch_size, seq_len, -1)

def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
10 changes: 10 additions & 0 deletions lightllm/server/api_cli.py
@@ -333,6 +333,16 @@ def make_argument_parser() -> argparse.ArgumentParser:
auto: automatically select best backend based on GPU and available packages
(priority: fa3 > flashinfer > triton)""",
)
parser.add_argument(
"--vit_att_backend",
type=str,
nargs="+",
choices=["auto", "triton", "fa3", "sdpa"],
default=["auto"],
help="""vit attention kernel used in vlm.
auto: automatically select best backend based on GPU and available packages
(priority: fa3 > sdpa > triton)""",
)
parser.add_argument(
"--llm_kv_type",
type=str,
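For reference, the new flag is declared with nargs="+" so it arrives as a list, and create_utils.py reads it by index. A hedged sketch of how the value flows from the CLI into the backend factory, assuming the server has already parsed its start args:

from lightllm.utils.envs_utils import get_env_start_args

args = get_env_start_args()
backend_str = args.vit_att_backend[0]  # "auto", "fa3", "sdpa" or "triton"
# "auto" triggers the validated auto-selection (priority: fa3 > sdpa > triton)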
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -121,6 +121,7 @@ class StartArgs:
llm_decode_att_backend: List[str] = field(
default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "flashinfer"]}
)
vit_att_backend: List[str] = field(default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa"]})
llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv"]})
llm_kv_quant_group_size: int = field(default=8)
sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})
16 changes: 16 additions & 0 deletions lightllm/server/visualserver/__init__.py
@@ -0,0 +1,16 @@
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
from lightllm.common.basemodel.attention_vit.create_utils import get_vit_att_backend_class

VIT_ATTN_BACKEND: BaseVitAttBackend = None


def init_vit_att_backend():
global VIT_ATTN_BACKEND
VIT_ATTN_BACKEND = get_vit_att_backend_class(index=0)()
return


def get_vit_attn_backend():
if VIT_ATTN_BACKEND is None:
raise RuntimeError("VIT_ATTN_BACKEND is not initialized. Call init_vit_att_backend() first.")
return VIT_ATTN_BACKEND._vit_att_fwd
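The intended lifecycle in the visual worker, mirroring the call in model_rpc.py below, is: initialize once per process, then fetch the bound forward. A minimal sketch, assuming the process has already parsed its start args:

from lightllm.server.visualserver import init_vit_att_backend, get_vit_attn_backend

init_vit_att_backend()             # resolves and instantiates the backend once per process
vit_attn = get_vit_attn_backend()  # bound _vit_att_fwd of the singleton backend

# later, inside a ViT layer forward:
# vit_attn(q, k, v, attn_output, cu_seqlens, max_seqlen)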
4 changes: 3 additions & 1 deletion lightllm/server/visualserver/model_infer/model_rpc.py
@@ -24,6 +24,8 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient
from lightllm.server.visualserver import init_vit_att_backend
from lightllm.utils.dist_utils import set_global_rank


class VisualModelRpcServer(rpyc.Service):
@@ -42,8 +44,8 @@ def exposed_init_model(self, kvargs):
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.data_type = kvargs["data_type"]

init_vision_distributed_env(kvargs)
init_vit_att_backend()
model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)

try:
1 change: 1 addition & 0 deletions test/benchmark/static_inference/model_infer.py
@@ -43,6 +43,7 @@ def test_model_inference(args):
"disable_cudagraph": args.disable_cudagraph,
"llm_prefill_att_backend": args.llm_prefill_att_backend,
"llm_decode_att_backend": args.llm_decode_att_backend,
"vit_att_backend": args.vit_att_backend,
"llm_kv_type": args.llm_kv_type,
"llm_kv_quant_group_size": args.llm_kv_quant_group_size,
}