@@ -65,7 +65,6 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,
-        seqlens: torch.Tensor,
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -141,15 +140,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: int | None = None,  # Only used for Flash Attention
-        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x = x + self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x = x + self.mlp(self.norm2(x))
         return x
@@ -198,7 +195,6 @@ def __init__(
             head_size=head_dim,
             rotary_dim=head_dim // 2,
             max_position=8192,
-            base=10000.0,
             is_neox_style=True,
         )

@@ -300,15 +296,14 @@ def forward(
         x = x.unsqueeze(1)

         # pre-compute seqlens for attn mask to reduce cuMemcpy operations
-        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
         for blk in self.blocks:
             x = blk(
                 x,
                 cu_seqlens=cu_seqlens,
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen,
-                seqlens=seqlens,
             )

         # adapter
@@ -326,15 +321,13 @@ def forward(
         rotary_pos_emb_cos: torch.Tensor,
         rotary_pos_emb_sin: torch.Tensor,
         max_seqlen: torch.Tensor,  # Only used for Flash Attention
-        seqlens: torch.Tensor,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb_cos=rotary_pos_emb_cos,
             rotary_pos_emb_sin=rotary_pos_emb_sin,
             max_seqlen=max_seqlen,
-            seqlens=seqlens,
         )
         x_fused_norm, residual = self.norm2(x, residual=x_attn)
         x = residual + self.mlp(x_fused_norm)
@@ -552,10 +545,8 @@ def forward(

         # transformers
         # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
-        max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(
-            cu_seqlens)
-        max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
-            cu_window_seqlens)
+        max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
+        max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)

         cu_seqlens = cu_seqlens.to(  # type: ignore[attr-defined]
             device=self.device,
@@ -586,19 +577,16 @@ def forward(
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
                 max_seqlen_now = max_seqlen_full
-                seqlens_now = seqlens_full
             else:
                 cu_seqlens_now = cu_window_seqlens
                 max_seqlen_now = max_seqlen_window
-                seqlens_now = seqlens_window

             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
                 rotary_pos_emb_cos=rotary_pos_emb_cos,
                 rotary_pos_emb_sin=rotary_pos_emb_sin,
                 max_seqlen=max_seqlen_now,
-                seqlens=seqlens_now,
             )

             # For Qwen2.5-VL-3B, float16 will overflow at last block