@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "attention_mask_impl.h"
+#include "attention_mask.h"
 
 namespace xllm {
 namespace layer {
 
-AttentionMaskImpl::AttentionMaskImpl(at::Device device,
-                                     torch::Dtype dtype,
-                                     float mask_value) {
+AttentionMask::AttentionMask(at::Device device,
+                             torch::Dtype dtype,
+                             float mask_value) {
   int max_seq_len = 128;
   seq_len_cached_ = max_seq_len;
   auto bias_cache =
@@ -37,25 +37,24 @@ AttentionMaskImpl::AttentionMaskImpl(at::Device device,
           .to(device);
 }
 
-torch::Tensor AttentionMaskImpl::get_decode_attn_mask(
-    torch::Tensor input_lengths,
-    int64_t max_s,
-    torch::Dtype dtype,
-    torch::Device device) {
+torch::Tensor AttentionMask::get_decode_attn_mask(torch::Tensor input_lengths,
+                                                  int64_t max_s,
+                                                  torch::Dtype dtype,
+                                                  torch::Device device) {
   update_attn_cache(dtype, device, max_s);
   return atten_mask_cache_.index_select(0, input_lengths).view({-1, 1, max_s});
 }
 
-torch::Tensor AttentionMaskImpl::get_attn_mask(int64_t max_s,
-                                               torch::Dtype dtype,
-                                               torch::Device device) {
+torch::Tensor AttentionMask::get_attn_mask(int64_t max_s,
+                                           torch::Dtype dtype,
+                                           torch::Device device) {
   update_attn_cache(dtype, device, max_s);
   return atten_mask_cache_.slice(0, 0, max_s).slice(1, 0, max_s);
 }
 
-torch::Tensor AttentionMaskImpl::gen_free_mask(int32_t q_len,
-                                               torch::Dtype dtype,
-                                               torch::Device device) {
+torch::Tensor AttentionMask::gen_free_mask(int32_t q_len,
+                                           torch::Dtype dtype,
+                                           torch::Device device) {
   float pre_mask_factor = -10000.0f;
   if (dtype == torch::kBFloat16) {
     pre_mask_factor = 1.0f;
@@ -68,11 +67,11 @@ torch::Tensor AttentionMaskImpl::gen_free_mask(int32_t q_len,
   return mask_free;
 }
 
-torch::Tensor AttentionMaskImpl::gen_append_mask(int32_t q_len,
-                                                 int32_t kv_len,
-                                                 int32_t max_kv_len,
-                                                 torch::Dtype dtype,
-                                                 torch::Device device) {
+torch::Tensor AttentionMask::gen_append_mask(int32_t q_len,
+                                             int32_t kv_len,
+                                             int32_t max_kv_len,
+                                             torch::Dtype dtype,
+                                             torch::Device device) {
   int diagonal = kv_len - q_len;
   auto options = torch::TensorOptions().dtype(torch::kBool).device(device);
   auto bias = torch::tril(torch::ones({q_len, max_kv_len}, options), diagonal);
@@ -84,9 +83,9 @@ torch::Tensor AttentionMaskImpl::gen_append_mask(int32_t q_len,
   return mask;
 }
 
-void AttentionMaskImpl::update_attn_cache(torch::Dtype dtype,
-                                          torch::Device device,
-                                          int64_t seqlen) {
+void AttentionMask::update_attn_cache(torch::Dtype dtype,
+                                      torch::Device device,
+                                      int64_t seqlen) {
   if (seqlen > seq_len_cached_ || atten_mask_cache_.dtype() != dtype) {
     seq_len_cached_ = seqlen;
 
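For reference, a minimal call-site sketch after the rename. It is not part of this commit: the include path, namespace, and the signatures of the constructor and `get_attn_mask` are taken from the diff above, while the device, dtype, and `mask_value` arguments are hypothetical placeholders chosen for illustration.

```cpp
// Usage sketch (not part of this commit): assumes the renamed header and the
// xllm::layer namespace shown in the diff above.
#include "attention_mask.h"

#include <torch/torch.h>

int main() {
  torch::Device device(torch::kCPU);  // hypothetical device choice

  // Construct the helper through its new name; before this change the same
  // (device, dtype, mask_value) arguments went to AttentionMaskImpl.
  xllm::layer::AttentionMask mask(device, torch::kFloat32, /*mask_value=*/1.0f);

  // Build a [max_s, max_s] causal mask slice, e.g. for a 16-token prefill.
  torch::Tensor attn = mask.get_attn_mask(/*max_s=*/16, torch::kFloat32, device);
  return 0;
}
```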