@@ -85,7 +85,6 @@ class Glm4ModelImpl : public LlmModelImplBase<Glm4DecoderLayer> {
     } else {
       h = embed_tokens_(tokens, 0);
     }
-
     auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0);
     auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
     auto cos_pos = target_cos_sin_chunks[0].contiguous();
@@ -98,7 +97,7 @@ class Glm4ModelImpl : public LlmModelImplBase<Glm4DecoderLayer> {
     for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) {
       int64_t offset = dim_idx;
       int64_t section_len = mrope_section_[dim_idx];
-      int64_t length = section_len * 3;
+      int64_t length = section_len * 2;
       auto idx_first_half = torch::arange(offset, length, 3, torch::kLong);
       auto idx_second_half = torch::arange(offset, length, 3, torch::kLong);
       auto idx_tensor =
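This hunk tightens the `torch::arange` bound used to build the interleaved m-RoPE index pattern from `section_len * 3` to `section_len * 2`. A minimal standalone sketch of what each bound selects, assuming hypothetical section lengths `{8, 12, 12}` (the real values come from `mrope_section_`, which is set from the model config):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical m-RoPE section lengths for the t/h/w axes; the real
  // values live in mrope_section_ and are read from the model config.
  std::vector<int64_t> mrope_section = {8, 12, 12};

  for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) {
    int64_t offset = dim_idx;  // h indices start at 1, w indices at 2
    int64_t section_len = mrope_section[dim_idx];
    // Old bound (section_len * 3) vs. the fixed bound (section_len * 2),
    // both stepping by 3 across the interleaved layout.
    auto idx_old = torch::arange(offset, section_len * 3, 3, torch::kLong);
    auto idx_new = torch::arange(offset, section_len * 2, 3, torch::kLong);
    std::cout << "dim " << dim_idx << ": old selects " << idx_old.numel()
              << " indices, new selects " << idx_new.numel() << "\n";
  }
  return 0;
}
```

For `section_len = 12` and `dim_idx = 1`, the old bound selects 12 indices (1, 4, …, 34) while the new one selects 8 (1, 4, …, 22), so the fix shrinks how far the strided gather reaches into the cos/sin tables.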
@@ -114,7 +113,8 @@ class Glm4ModelImpl : public LlmModelImplBase<Glm4DecoderLayer> {
       sin_pos = apply(sin_pos.reshape(
           {positions.sizes().front(), -1, sin_pos.sizes().back()}));
     }
-
+    cos_pos = cos_pos.reshape({-1, cos_pos.sizes().back() / 2, 2});
+    sin_pos = sin_pos.reshape({-1, sin_pos.sizes().back() / 2, 2});
     torch::Tensor attn_mask;
     if (FLAGS_enable_chunked_prefill) {
       int max_kv_seq = input_params.kv_max_seq_len;
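The last hunk adds a reshape of the cos/sin tables to `(-1, head_dim / 2, 2)` before attention, grouping the last dimension into pairs of adjacent rotary components. A short sketch of the shape change, assuming a hypothetical head dimension of 64 and 4 positions:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical cos table: 4 positions x head_dim = 64.
  auto cos_pos = torch::randn({4, 64});
  // Same reshape as the patch: group the last dimension into
  // (head_dim / 2) pairs of adjacent components.
  cos_pos = cos_pos.reshape({-1, cos_pos.sizes().back() / 2, 2});
  std::cout << cos_pos.sizes() << std::endl;  // prints [4, 32, 2]
  return 0;
}
```

`reshape` only reinterprets the layout, so no values change; the patch presumably feeds a downstream consumer that expects the paired `(…, head_dim / 2, 2)` layout rather than a flat last dimension.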