-
Notifications
You must be signed in to change notification settings - Fork 634
[ascend] refactor code #4176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[ascend] refactor code #4176
Changes from 5 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
e388a94
refactor ascend op_backend
yao-fengchen 080dc12
refactor mask
yao-fengchen 0840a82
format code
yao-fengchen 24cb832
remove 310P judge
yao-fengchen 9a9604a
remove unused code
yao-fengchen bd877fe
update code
yao-fengchen 74501d6
update ascend build_graph_runner
yao-fengchen 2a945c7
remove old version of ascend graph_runner
yao-fengchen File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,6 @@ | |
| import itertools | ||
| import os | ||
| import re | ||
| from functools import lru_cache | ||
| from pathlib import Path | ||
| from typing import Dict, Tuple | ||
|
|
||
|
|
@@ -13,34 +12,22 @@ | |
| from lmdeploy.utils import get_logger | ||
|
|
||
| from ..op_backend import DlinferOpsBackend | ||
| from .utils import nd_to_nz_spec | ||
|
|
||
| logger = get_logger('lmdeploy') | ||
|
|
||
|
|
||
| class SocVersion: | ||
| Ascend310P: str = 'Ascend310P' | ||
| Ascend910: str = 'Ascend910' | ||
|
|
||
| @classmethod | ||
| @lru_cache(maxsize=1) | ||
| def device_name(cls) -> str: | ||
| try: | ||
| import torch_npu | ||
| return torch_npu.npu.get_device_name() | ||
| except ImportError: | ||
| logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ') | ||
| except Exception as e: | ||
| logger.warning(f'Error during Ascend get device name: {str(e)}. ' | ||
| 'Please check your Ascend environment configuration.') | ||
| device_name: str = torch.npu.get_device_name() | ||
|
|
||
| @classmethod | ||
| def is_Ascend310P(cls) -> bool: | ||
| return cls.device_name().startswith(cls.Ascend310P) | ||
| return cls.device_name.startswith(cls.Ascend310P) | ||
|
|
||
| @classmethod | ||
| def is_Ascend910(cls) -> bool: | ||
| return cls.device_name().startswith(cls.Ascend910) | ||
| return cls.device_name.startswith(cls.Ascend910) | ||
|
|
||
|
|
||
| class AscendKVQuantMeta: | ||
|
|
@@ -108,12 +95,6 @@ def get_k_block_shape( | |
| ) -> Tuple[int, ...]: | ||
| if SocVersion.is_Ascend910(): | ||
| return (block_size, num_heads, head_size) | ||
| elif SocVersion.is_Ascend310P(): | ||
| return ( | ||
| (num_heads * head_size + 15) // 16, | ||
| block_size, | ||
| 16, | ||
| ) | ||
| else: | ||
| raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.') | ||
yao-fengchen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
@@ -126,160 +107,116 @@ def get_v_block_shape( | |
| ) -> Tuple[int, ...]: | ||
| if SocVersion.is_Ascend910(): | ||
| return (block_size, num_heads, head_size) | ||
| elif SocVersion.is_Ascend310P(): | ||
| return ( | ||
| (num_heads * head_size + 15) // 16, | ||
| block_size, | ||
| 16, | ||
| ) | ||
| else: | ||
| raise ValueError(f'dlinfer does not support {SocVersion.device_name()} device currently.') | ||
yao-fengchen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @classmethod | ||
| @lru_cache(maxsize=1) | ||
| def enable_aclgraph(cls) -> bool: | ||
| if os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'aclgraph' and not SocVersion.is_Ascend310P(): | ||
| return True | ||
| elif os.getenv('ASCEND_GRAPH_MODE', 'aclgraph') == 'atbgraph' or SocVersion.is_Ascend310P(): | ||
| return False | ||
| else: | ||
| raise ValueError(f"unsupported ASCEND_GRAPH_MODE: {os.getenv('ASCEND_GRAPH_MODE')}") | ||
|
|
||
| @classmethod | ||
| def update_step_context(cls, step_context): | ||
| """Update step context.""" | ||
|
|
||
| kv_start_indices, attention_mask = [], [] | ||
yao-fengchen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
yao-fengchen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| block_num, block_size, *_ = step_context.kv_caches[0][0].shape | ||
| is_unpaged_prefill = False | ||
| if not step_context.is_decoding: | ||
| is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist()) | ||
| if step_context.block_offsets.dtype != torch.int32: | ||
| step_context.block_offsets = step_context.block_offsets.to(torch.int32) | ||
| if not (step_context.is_decoding or is_unpaged_prefill): | ||
| step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0) | ||
| if step_context.kv_seqlens.dtype != torch.int32: | ||
| step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32) | ||
| if step_context.q_seqlens.dtype != torch.int32: | ||
| step_context.q_seqlens = step_context.q_seqlens.to(torch.int32) | ||
|
|
||
| def get_total_slots(): | ||
| if cls.total_slots is None: | ||
| cls.total_slots = torch.arange(block_num * block_size, | ||
| dtype=torch.long, | ||
| dtype=torch.int32, | ||
| device=step_context.block_offsets.device) | ||
| cls.total_slots = cls.total_slots.view(block_num, block_size) | ||
| return cls.total_slots | ||
|
|
||
| kv_start_indices, attention_mask = [], [] | ||
| if SocVersion.is_Ascend910(): | ||
| block_num, block_size, *_ = step_context.kv_caches[0][0].shape | ||
| elif SocVersion.is_Ascend310P(): | ||
| block_num, _, block_size, _ = step_context.kv_caches[0][0].shape | ||
| def get_cpu_seqlens(is_decoding, is_unpaged_prefill): | ||
| if is_decoding: | ||
| q_seqlens_cpu, kv_seqlens_cpu = None, step_context.kv_seqlens.cpu() | ||
| elif is_unpaged_prefill: | ||
| q_seqlens_cpu = kv_seqlens_cpu = step_context.q_seqlens.cpu() | ||
| else: | ||
| q_seqlens_cpu = step_context.q_seqlens.cpu() | ||
| kv_seqlens_cpu = step_context.kv_seqlens.cpu() | ||
| return q_seqlens_cpu, kv_seqlens_cpu | ||
|
|
||
| is_unpaged_prefill = False | ||
| if not step_context.is_decoding: | ||
| is_unpaged_prefill = \ | ||
| all((step_context.q_seqlens == | ||
| step_context.kv_seqlens).tolist()) | ||
| q_seqlens_list = step_context.q_seqlens.tolist() | ||
| kv_seqlens_list = step_context.kv_seqlens.tolist() | ||
| max_q_seq_len = max(q_seqlens_list) | ||
| max_kv_seq_len = max(kv_seqlens_list) | ||
|
|
||
| if step_context.is_decoding: | ||
| # collect kv_start_indices without using a for-loop, | ||
| # (fill kv-cache for just ONE token during the decoding phase) | ||
| idx = (step_context.kv_seqlens - 1) % block_size | ||
| block_num = (step_context.kv_seqlens - 1) // block_size | ||
| last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1) | ||
| kv_start_indices = last_block * block_size + idx | ||
| else: | ||
| for i in range(step_context.q_start_loc.size(0)): | ||
| q_seq_len = q_seqlens_list[i] | ||
| kv_seq_len = kv_seqlens_list[i] | ||
|
|
||
| # collect kv start indices during the prefill phase. | ||
| history_length = kv_seq_len - q_seq_len | ||
| total_slots = get_total_slots() | ||
| slot_tables = total_slots[step_context.block_offsets[i]].view(-1) | ||
| slots = slot_tables[history_length:kv_seq_len] | ||
| kv_start_indices.append(slots) | ||
|
|
||
| # collect attention mask of paged_prefill attention stage. | ||
| if not is_unpaged_prefill: | ||
| single_attention_mask = torch.logical_not( | ||
| torch.tril( | ||
| def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None): | ||
| if is_decoding: | ||
| q_seqlens_list, kv_seqlens_list = None, None | ||
| elif is_unpaged_prefill: | ||
| q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist() | ||
| else: | ||
| q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist() | ||
| return q_seqlens_list, kv_seqlens_list | ||
|
|
||
| def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None): | ||
| if is_decoding: | ||
| max_q_seq_len, max_kv_seq_len = 1, None | ||
| elif is_unpaged_prefill: | ||
| max_q_seq_len = max_kv_seq_len = max(q_seqlens_list) | ||
| else: | ||
| max_q_seq_len = max(q_seqlens_list) | ||
| max_kv_seq_len = max(kv_seqlens_list) | ||
| return max_q_seq_len, max_kv_seq_len | ||
|
|
||
| def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list, | ||
| max_q_seq_len, max_kv_seq_len): | ||
| kv_start_indices, attention_mask = [], [] | ||
| if is_decoding: | ||
| idx = (step_context.kv_seqlens - 1) % block_size | ||
| block_num = (step_context.kv_seqlens - 1) // block_size | ||
| last_block = step_context.block_offsets.gather(1, block_num.view(-1, 1)).view(-1) | ||
| kv_start_indices = last_block * block_size + idx | ||
| else: | ||
| for i in range(step_context.q_start_loc.size(0)): | ||
| q_seq_len = q_seqlens_list[i] | ||
| kv_seq_len = kv_seqlens_list[i] | ||
|
|
||
| history_length = kv_seq_len - q_seq_len | ||
| total_slots = get_total_slots() | ||
| slot_tables = total_slots[step_context.block_offsets[i]].view(-1) | ||
| slots = slot_tables[history_length:kv_seq_len] | ||
| kv_start_indices.append(slots) | ||
|
|
||
| if not is_unpaged_prefill: | ||
| single_attention_mask = torch.triu( | ||
| torch.ones(q_seq_len, | ||
| step_context.block_offsets.shape[1] * block_size, | ||
| dtype=torch.bool, | ||
| device=step_context.block_offsets.device), | ||
| diagonal=kv_seq_len - q_seq_len, | ||
| )) | ||
| attention_mask.append(single_attention_mask) | ||
|
|
||
| kv_start_indices = torch.cat(kv_start_indices) | ||
|
|
||
| if step_context.is_decoding: | ||
| # prepare some params of paged_decode attention stage. | ||
| q_start_loc_cpu, q_seqlens_cpu = None, None | ||
| elif is_unpaged_prefill: | ||
| # prepare some params of unpaged_prefill attention stage. | ||
| q_start_loc_cpu, kv_seqlens_cpu = None, None | ||
| q_seqlens_cpu = step_context.q_seqlens.cpu().to(torch.int32) | ||
| if SocVersion.is_Ascend910(): | ||
| single_attention_mask = torch.logical_not( | ||
| torch.tril( | ||
| torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.bool).cuda(), | ||
| diagonal=max_kv_seq_len - max_q_seq_len, | ||
| )) | ||
| attention_mask.append(single_attention_mask) | ||
| elif SocVersion.is_Ascend310P(): | ||
| if not cls.enable_graph: | ||
| for i in range(q_seqlens_cpu.size(0)): | ||
| single_attention_mask = torch.zeros(q_seqlens_cpu[i], | ||
| q_seqlens_cpu[i]).fill_(-float('inf')).cuda() | ||
| single_attention_mask = torch.triu(single_attention_mask, diagonal=1) | ||
| diagonal=kv_seq_len - q_seq_len + 1, | ||
| ) | ||
| attention_mask.append(single_attention_mask) | ||
| else: | ||
| # Transdata needs dtype to be float16 or int8 | ||
| single_attention_mask = torch.triu( | ||
| torch.ones(max_q_seq_len, max_kv_seq_len, dtype=torch.float16).fill_(-float('inf')).cuda(), | ||
| diagonal=max_kv_seq_len - max_q_seq_len + 1, | ||
| ) | ||
| # Convert to NZ format | ||
| attention_mask.append(nd_to_nz_spec(single_attention_mask)) | ||
| else: | ||
| raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.") | ||
| else: | ||
| # prepare some params of paged_prefill attention stage. | ||
| q_start_loc_cpu, q_seqlens_cpu = None, None | ||
| attention_mask = [torch.cat([mask for mask in attention_mask])] | ||
|
|
||
| if cls.enable_graph: | ||
| kv_start_indices = kv_start_indices.view(-1).to(torch.int32) | ||
| import torch._dynamo as dynamo | ||
| if not is_unpaged_prefill: | ||
| step_context.block_offsets = step_context.block_offsets.to(torch.int32) | ||
| if not step_context.is_decoding: | ||
| step_context.block_offsets = step_context.block_offsets\ | ||
| .repeat_interleave(step_context.q_seqlens, 0) | ||
| dynamo.mark_dynamic(step_context.block_offsets, [0, 1]) | ||
| kv_seqlens = step_context.kv_seqlens.cpu().to(torch.int32) | ||
| if not step_context.is_decoding: | ||
|
|
||
| if is_unpaged_prefill: | ||
| if SocVersion.is_Ascend910(): | ||
| attention_mask = [mask.half() for mask in attention_mask] | ||
| attention_mask.append( | ||
| torch.triu(torch.ones(max_q_seq_len, | ||
| max_kv_seq_len, | ||
| dtype=step_context.kv_caches[0][0].dtype, | ||
| device=step_context.block_offsets.device), | ||
| diagonal=max_kv_seq_len - max_q_seq_len + 1)) | ||
| else: | ||
| if SocVersion.is_Ascend910(): | ||
| attention_mask = [ | ||
| torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask]).unsqueeze(1) | ||
| ] | ||
| elif SocVersion.is_Ascend310P(): | ||
| # Convert mask to NZ format. | ||
| attention_mask = [ | ||
| nd_to_nz_spec(torch.cat([mask.half() * cls.half_negative_inf for mask in attention_mask])) | ||
| ] | ||
| else: | ||
| raise ValueError(f"dlinfer doesn't support {SocVersion.device_name()} device currently.") | ||
| kv_seqlens = kv_seqlens.repeat_interleave(step_context.q_seqlens, 0) | ||
| else: | ||
| if step_context.is_decoding: | ||
| kv_seqlens_cpu = step_context.kv_seqlens.cpu().to(torch.int32) | ||
| elif is_unpaged_prefill: | ||
| pass | ||
| else: | ||
| kv_seqlens_cpu = step_context.kv_seqlens.repeat_interleave(step_context.q_seqlens, 0).cpu() | ||
| block_offsets_int32 = step_context.block_offsets.to(torch.int32) | ||
| step_context.block_offsets = block_offsets_int32\ | ||
| .repeat_interleave(step_context.q_seqlens, 0) | ||
| kv_seqlens = kv_seqlens_cpu | ||
| attention_mask = [torch.cat(attention_mask)] | ||
|
|
||
| kv_start_indices = torch.cat(kv_start_indices) | ||
|
|
||
| return kv_start_indices, attention_mask | ||
|
|
||
| q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill) | ||
| q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu, | ||
| kv_seqlens_cpu) | ||
| max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list, | ||
| kv_seqlens_list) | ||
| kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding, | ||
| is_unpaged_prefill, q_seqlens_list, | ||
| kv_seqlens_list, max_q_seq_len, | ||
| max_kv_seq_len) | ||
|
|
||
| if not cls.enable_graph and step_context.kv_quant_policy == 8: | ||
| record_file = os.getenv('ASCEND_QUANT_RECORD_FILE') | ||
|
|
@@ -298,9 +235,9 @@ def get_total_slots(): | |
| attn_metadata = attn_meta_cls( | ||
| step_context.is_decoding, | ||
| step_context.block_offsets, | ||
| q_start_loc=q_start_loc_cpu, | ||
| q_start_loc=None, | ||
| q_seqlens=q_seqlens_cpu, | ||
| kv_seqlens=kv_seqlens, | ||
| kv_seqlens=kv_seqlens_cpu, | ||
| kv_start_indices=kv_start_indices, | ||
| block_size=block_size, | ||
| attention_mask=attention_mask, | ||
|
|
@@ -318,23 +255,14 @@ def get_total_slots(): | |
| def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig, | ||
| backend_config: BackendConfig, device: torch.device): | ||
| """Build graph runner.""" | ||
| if AscendOpsBackend.enable_aclgraph(): | ||
| from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner | ||
| return CUDAGraphRunner(model, model_config, cache_config, backend_config, device) | ||
| else: | ||
| from .graph_runner import AscendGraphRunner | ||
| ascend_graph_runner = AscendGraphRunner(model, model_config, cache_config, backend_config, device) | ||
| AscendOpsBackend.enable_graph = ascend_graph_runner.enable_graph | ||
| return ascend_graph_runner | ||
| from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner | ||
| return CUDAGraphRunner(model, model_config, cache_config, backend_config, device) | ||
|
||
|
|
||
| @staticmethod | ||
| def init(): | ||
| """Initialize Ascend backend.""" | ||
| try: | ||
| from torch_npu.contrib import transfer_to_npu # noqa: F401 | ||
| if SocVersion.is_Ascend310P(): | ||
| # NOTE: Ascend310P has a bug with InternVL vision embedding using interpolate. | ||
| torch.npu.set_compile_mode(jit_compile=False) | ||
| except ImportError: | ||
| logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ' | ||
| 'Ascend initialization skipped.') | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.