diff --git a/lightllm/__init__.py b/lightllm/__init__.py
index e69de29bb..e9ba6f304 100644
--- a/lightllm/__init__.py
+++ b/lightllm/__init__.py
@@ -0,0 +1,4 @@
+from lightllm.utils.device_utils import is_musa
+
+if is_musa():
+    import torchada  # noqa: F401
diff --git a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
index eb5af6fec..45de83e98 100644
--- a/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
+++ b/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
@@ -60,7 +60,8 @@ def _fwd_kernel_token_att1(
     ).to(tl.int64)
     off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
     k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
-    att_value = tl.sum(q[None, :] * k, 1, dtype=tl.float32)
+    att_value = tl.sum(q[None, :] * k, 1)
+    att_value = att_value.to(tl.float32)
     att_value *= sm_scale
     off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
     tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 7b9cdd501..de6d6ba20 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -1,7 +1,7 @@
 import time
 import uuid

-from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from typing import Any, Dict, List, Optional, Union, Literal, ClassVar

 from transformers import GenerationConfig
@@ -21,6 +21,14 @@ class Message(BaseModel):
     content: Union[str, List[MessageContent]]


+class CharacterMessage(BaseModel):
+    """Message format for character-based chat, where role is inferred from name."""
+
+    name: str
+    content: Union[str, List[MessageContent]]
+    role: Optional[str] = None  # Optional, can be inferred from role_setting
+
+
 class Function(BaseModel):
     """Function descriptions."""

@@ -105,7 +113,7 @@ def _normalize_role(cls, v):
         raise ValueError("'role' must be a string")


-ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message]
+ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message, CharacterMessage]


 class CompletionRequest(BaseModel):
@@ -176,6 +184,8 @@ def apply_loaded_defaults(cls, data: Any):


 class ChatCompletionRequest(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
     model: str
     messages: List[ChatCompletionMessageParam]
     function_call: Optional[str] = "none"
@@ -216,8 +226,9 @@ class ChatCompletionRequest(BaseModel):
     top_k: Optional[int] = -1
     repetition_penalty: Optional[float] = 1.0
     ignore_eos: Optional[bool] = False
-    role_settings: Optional[Dict[str, str]] = None
+    role_settings: Optional[Dict[str, str]] = Field(default=None, alias="role_setting")
     character_settings: Optional[List[Dict[str, str]]] = None
+    system_instruction: Optional[str] = None

     # Class variables to store loaded default values
     _loaded_defaults: ClassVar[Dict[str, Any]] = {}
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 6a8c232dc..cdac4ab5c 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -105,7 +105,8 @@ def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int:
     messages = getattr(request, "messages", [])
     idx = 0
     for msg in messages:
-        if msg.role == "assistant":
+        role = getattr(msg, "role", None)
+        if role == "assistant":
             tool_calls = getattr(msg, "tool_calls", None)
             idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
     return idx
diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index f770459a5..cff2ec127 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -22,6 +22,7 @@ async def build_prompt(request, tools) -> str:
         kwargs["character_settings"] = request.character_settings
     if request.role_settings:
         kwargs["role_setting"] = request.role_settings
+    kwargs["system_instruction"] = request.system_instruction
     if request.chat_template_kwargs:
         kwargs.update(request.chat_template_kwargs)

diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py
index cd48a355b..09d7a680f 100644
--- a/lightllm/utils/device_utils.py
+++ b/lightllm/utils/device_utils.py
@@ -81,11 +81,14 @@ def calcu_kernel_best_vsm_count(kernel, num_warps):
     return num_sm


+@lru_cache(maxsize=1)
+def is_musa():
+    return hasattr(torch.version, "musa") and torch.version.musa is not None
+
+
 @lru_cache(maxsize=None)
 def get_current_device_name():
-    import torch
-
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or is_musa():
         device = torch.cuda.current_device()
         gpu_name = torch.cuda.get_device_name(device)
         # 4090 trans to 4090 D
@@ -103,8 +106,6 @@ def init_p2p(device_index):
     """
     torch 调用跨卡的to操作后,triton编译的算子便能自动操作跨卡tensor。
     """
-    import torch
-
     num_gpus = torch.cuda.device_count()
     tensor = torch.zeros((1,))
     tensor = tensor.to(f"cuda:{device_index}")
@@ -127,8 +128,26 @@ def has_nvlink():
         result = result.decode("utf-8")
         # Check if the output contains 'NVLink'
         return any(f"NV{i}" in result for i in range(1, 8))
+    except FileNotFoundError:
+        # nvidia-smi is not installed, assume no NVLink
+        return False
+    except subprocess.CalledProcessError:
+        # If there's an error while executing nvidia-smi, assume no NVLink
+        return False
+
+
+def has_mtlink():
+    try:
+        # Call mthreads-gmi to get the topology matrix
+        result = subprocess.check_output(["mthreads-gmi", "topo", "--matrix"])
+        result = result.decode("utf-8")
+        # Check if the output contains 'MTLink'
+        return any(f"MT{i}" in result for i in range(1, 8))
+    except FileNotFoundError:
+        # mthreads-gmi is not installed, assume no MTLink
+        return False
     except subprocess.CalledProcessError:
-        # If there's an error (e.g., nvidia-smi is not installed or another issue), assume no NVLink
+        # If there's an error while executing mthreads-gmi, assume no MTLink
         return False