Add Ulysses SP support for FLA gated-delta context parallelism #8114
xylian86 wants to merge 2 commits into
Signed-off-by: Xinyu Lian <lian7@illinois.edu>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 368c1221f5
unsupported_kwargs = sorted(set(kwargs) - set(_IGNORED_LINEAR_ATTENTION_FORWARD_KWARGS))
if args or unsupported_kwargs:
    raise RuntimeError("Linear attention CP support received unsupported extra forward arguments: "
                       f"args={len(args)}, kwargs={unsupported_kwargs}")
Accept packed sequence metadata in GDN CP
For packed SFT batches, Transformers' Qwen3.5 GatedDeltaNet forwards cu_seq_lens_q in kwargs and the original implementation passes it into chunk_gated_delta_rule to reset recurrent state at document boundaries. This wrapper rejects that key as unsupported before building cp_context, so packed batches either fail here instead of training or lose the boundary metadata if callers omit it; please consume the packed-sequence kwargs (or derive equivalent cu-seqlens) rather than treating them as unsupported.
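One way to act on this suggestion is to separate packed-sequence metadata from genuinely unsupported keys before raising. The sketch below is illustrative only: `_PACKED_SEQUENCE_KWARGS`, `split_forward_kwargs`, and the placeholder contents of `_IGNORED_LINEAR_ATTENTION_FORWARD_KWARGS` are assumptions made for the example, not code from this PR.

```python
# Placeholder contents; the real ignored set lives in the PR and is not shown here.
_IGNORED_LINEAR_ATTENTION_FORWARD_KWARGS = frozenset({"output_attentions"})
# Assumed set of packed-sequence keys; only cu_seq_lens_q is named in the review.
_PACKED_SEQUENCE_KWARGS = frozenset({"cu_seq_lens_q"})


def split_forward_kwargs(kwargs):
    """Return packed-sequence metadata and reject anything else that is unknown."""
    packed = {k: v for k, v in kwargs.items() if k in _PACKED_SEQUENCE_KWARGS}
    known = _IGNORED_LINEAR_ATTENTION_FORWARD_KWARGS | _PACKED_SEQUENCE_KWARGS
    unsupported = sorted(set(kwargs) - known)
    if unsupported:
        raise RuntimeError(
            "Linear attention CP support received unsupported extra forward "
            f"arguments: kwargs={unsupported}"
        )
    return packed
```

The returned metadata could then feed the CP context instead of being dropped, so document boundaries in packed batches survive the wrapper.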
sp_group, sp_world_size, _ = _get_sequence_parallel_info()
if sp_world_size == 1:
    return _call_original_linear_attention_forward(self, hidden_states, cache_params, attention_mask, position_ids,
                                                   *args, **kwargs)
Honor eval disablement for linear CP
When register_with_transformers(..., disable_in_eval=True) is used, the full-attention wrapper bypasses sequence-parallel collectives in eval, but patched gated-delta layers still enter this CP path whenever the SP group exists. In HF Trainer eval flows that rely on disable_in_eval because eval batches are not SP-sharded like training batches, these linear layers can still run FLA CP collectives or compute on the wrong distribution; carry the disable flag into this patch and call the original forward when not self.training.
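A minimal sketch of the requested behavior, using a toy layer rather than the real gated-delta module. `ToyLinearAttention`, `patch_forward`, and `toy_cp_forward` are hypothetical names for illustration; only the `disable_in_eval` flag and the `self.training` check come from the comment above.

```python
class ToyLinearAttention:
    """Stand-in for a gated-delta layer; only the `training` flag matters here."""

    def __init__(self):
        self.training = True

    def forward(self, hidden_states):
        return ("original", hidden_states)


def toy_cp_forward(self, hidden_states):
    """Stand-in for the context-parallel forward path."""
    return ("cp", hidden_states)


def patch_forward(layer_cls, cp_forward, disable_in_eval):
    """Install a forward that falls back to the original one during eval."""
    original_forward = layer_cls.forward

    def patched(self, hidden_states, *args, **kwargs):
        if disable_in_eval and not self.training:
            return original_forward(self, hidden_states, *args, **kwargs)
        return cp_forward(self, hidden_states, *args, **kwargs)

    layer_cls.forward = patched
    return original_forward


patch_forward(ToyLinearAttention, toy_cp_forward, disable_in_eval=True)
```

With the flag captured in the closure, the linear-attention layers would make the same train/eval decision as the full-attention wrapper.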
Bringing Ulysses Sequence Parallelism to Hybrid Linear Attention Models
Ulysses sequence parallelism is a practical way to train long-context models. For dense attention, the boundary is relatively clean: shard the sequence, run an all-to-all around the attention kernel, and aggregate losses across sequence-parallel ranks.
Hybrid models make this harder. Qwen3.5/3.6/Kimi-style models mix full attention layers with gated-delta linear attention layers. Those linear attention layers are not stateless attention kernels. They depend on causal convolution state, recurrent state, packed-sequence metadata, and kernel-specific assumptions about how chunks are laid out across ranks.
This PR adds DeepSpeed Ulysses support for those gated-delta linear attention layers.
Why this was not a one-line integration
The dense-attention Ulysses path wraps the attention function. That works because the dense attention function receives Q/K/V and returns the context tensor. Linear attention does not expose the same clean boundary.
For gated-delta layers, the causal convolution before the recurrent rule also needs context from neighboring sequence shards. The recurrent kernel itself needs a context-parallel view of the full logical sequence. If either piece treats the local shard as an independent sequence, the output can still be finite, but it is no longer the same model.
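The convolution half of that claim can be checked with a toy example. The sketch below is plain Python, not the FLA kernel: `causal_conv1d` is a hypothetical helper that applies a causal filter to a list of floats, optionally with left context taken from the previous shard.

```python
def causal_conv1d(x, weights, left_context=()):
    """y[t] = sum_j weights[j] * x[t - (K - 1) + j], K = len(weights).

    Positions before the start of x are read from `left_context` when given
    and are zero otherwise.
    """
    halo = len(weights) - 1
    ctx = list(left_context)[max(0, len(left_context) - halo):]
    padded = [0.0] * (halo - len(ctx)) + ctx + list(x)
    return [
        sum(w * padded[t + j] for j, w in enumerate(weights))
        for t in range(len(x))
    ]


weights = [1.0, 1.0, 1.0, 1.0]            # kernel size 4
x = [float(v) for v in range(1, 9)]       # one logical sequence of 8 tokens
shard0, shard1 = x[:4], x[4:]             # SP=2

full = causal_conv1d(x, weights)
independent = causal_conv1d(shard0, weights) + causal_conv1d(shard1, weights)
with_halo = causal_conv1d(shard0, weights) + causal_conv1d(
    shard1, weights, left_context=shard0
)
```

Here `with_halo` matches `full`, while `independent` differs on the first three tokens of the second shard: every value is finite, but the function is no longer the same.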
The dependency story also matters. Linear attention context parallelism depends on FLA APIs such as:
Implementation details
The implementation stays inside the existing ulysses_sp.py integration instead of adding a model-specific file.
At registration time, DeepSpeed checks whether the model may contain linear attention. If not, dense-attention-only models keep the existing path. If the model does use linear attention, DeepSpeed imports the Transformers modeling module, finds supported gated-delta layer classes, validates the FLA CP APIs, and patches only those gated-delta forwards.
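The registration flow described above can be sketched roughly as follows. Every name here is hypothetical (`register_linear_attention_cp`, the `layer_types` list, the `supported_names` and `required_names` arguments); the real checks and class names live in the PR and are not reproduced.

```python
def find_supported_layers(modeling_module, supported_names):
    """Return the classes in `modeling_module` whose names are supported."""
    found = []
    for name in supported_names:
        candidate = getattr(modeling_module, name, None)
        if isinstance(candidate, type):
            found.append(candidate)
    return found


def validate_cp_apis(kernel_module, required_names):
    """Fail early if the kernel library lacks an entry point the CP path needs."""
    missing = [n for n in required_names if not hasattr(kernel_module, n)]
    if missing:
        raise ImportError(f"missing context-parallel APIs: {missing}")


def register_linear_attention_cp(layer_types, modeling_module, kernel_module,
                                 supported_names, required_names, cp_forward):
    """Patch gated-delta forwards only when the model uses linear attention."""
    if "linear_attention" not in layer_types:
        return []  # dense-attention-only model: leave everything untouched
    validate_cp_apis(kernel_module, required_names)
    patched = []
    for layer_cls in find_supported_layers(modeling_module, supported_names):
        layer_cls.forward = cp_forward
        patched.append(layer_cls.__name__)
    return patched
```

Validating before patching means a missing kernel API surfaces at registration time rather than in the middle of the first training step.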
The most important implementation detail is how we preserve global sequence metadata after sequence sharding. Ulysses splits activations across ranks, but the model is still computing one logical sequence. Any metadata that describes token order, document boundaries, or packed-sequence resets must continue to describe the original full sequence, not the local shard.
Data sharding and position IDs
Ulysses sharding happens in the dataloader adapter, not inside the dataset. The dataset or collator must return full-sequence tensors:
For a normal non-packed sequence, position_ids are simply [0, 1, ..., seq_len - 1]. For packed sequences, position_ids reset to 0 at each document boundary.
The dataloader adapter first gathers one full batch from each SP rank, then slices every tensor on the sequence dimension. After sharding, each rank holds only its local token span, but the local position_ids still contain the original global/document-relative positions for that span.
That requirement is intentional. If position_ids were generated after sharding, every rank would create a local sequence like [0, 1, ..., local_seq_len - 1]. When those local positions are gathered later, the combined sequence would look like it has artificial document boundaries at every SP boundary.
For example, assume a single non-packed sequence of length 16 and SP=4. The correct full-sequence metadata is position_ids = [0, 1, 2, ..., 15]. If each rank generates position_ids locally after sharding, gathering them gives [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3].
That second sequence is not a length-16 sequence anymore. It looks like four packed documents of length 4. This is not a cosmetic metadata bug; it changes the function being computed.
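The length-16, SP=4 example can be reproduced in a few lines of plain Python. `shard` and `count_document_starts` are hypothetical helpers written for this illustration; they are not part of the dataloader adapter.

```python
def shard(seq, sp_size):
    """Split a full-sequence list into sp_size contiguous, equal-length shards."""
    assert len(seq) % sp_size == 0
    step = len(seq) // sp_size
    return [seq[r * step:(r + 1) * step] for r in range(sp_size)]


def count_document_starts(position_ids):
    """A position of 0 is read as the start of a document."""
    return sum(1 for p in position_ids if p == 0)


seq_len, sp_size = 16, 4

# Correct: build position_ids for the full sequence, then shard.
sharded_after_build = shard(list(range(seq_len)), sp_size)
# Incorrect: every rank builds its own position_ids after sharding.
built_after_shard = [list(range(seq_len // sp_size)) for _ in range(sp_size)]

gathered_correct = [p for rank in sharded_after_build for p in rank]
gathered_wrong = [p for rank in built_after_shard for p in rank]
```

`count_document_starts(gathered_correct)` is 1 and `count_document_starts(gathered_wrong)` is 4, which is exactly the difference between one length-16 document and four packed documents of length 4.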
For dense attention, packed-sequence logic may use the reset points to prevent tokens from attending across document boundaries. The incorrect metadata would prevent rank 1 tokens from attending to rank 0 tokens, rank 2 tokens from attending to rank 1 tokens, and so on.
For linear attention, the effect is even more direct. Gated-delta attention is recurrent. A reset at an SP boundary tells the recurrent rule to drop the previous state exactly where it should have carried state forward. The causal convolution has the same issue: tokens at the beginning of rank 1 need convolution context from the end of rank 0. If the SP boundary is treated as a document boundary, the convolution is padded/reset instead of using the true left context.
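To make the recurrent case concrete, here is a toy scalar recurrence with a fixed decay. It is deliberately much simpler than the gated-delta rule and is only meant to show what a spurious reset does; `gated_scan` and the decay value are assumptions for the example.

```python
def gated_scan(x, position_ids, decay=0.5):
    """h_t = decay * h_{t-1} + x_t, with h reset to 0 where position_ids == 0."""
    h = 0.0
    out = []
    for x_t, pos in zip(x, position_ids):
        if pos == 0:
            h = 0.0  # document start: drop the previous state
        h = decay * h + x_t
        out.append(h)
    return out


x = [1.0] * 8
global_ids = list(range(8))               # one document, sharded with SP=2
local_ids = [0, 1, 2, 3, 0, 1, 2, 3]      # ids regenerated per shard

correct = gated_scan(x, global_ids)
wrong = gated_scan(x, local_ids)
```

Both runs agree on the first shard. At the first token of the second shard, `correct` carries the state forward (1.9375) while `wrong` restarts from zero (1.0), and the error persists through the rest of the shard.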
The opposite mistake is also possible. In packed training, if we lose real position_ids reset points, the recurrent state can leak from one packed document into the next. So the invariant is: position_ids reset at every real document boundary and never at a sequence-parallel shard boundary.
Dense attention path
For full attention layers, the existing Ulysses wrapper all-to-all redistributes Q/K/V across the sequence and head dimensions. Attention backends still need correct global sequence metadata after that redistribution.
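The layout change performed by that all-to-all can be simulated without any distributed runtime. In this sketch each tensor entry is replaced by a `(global_token, head)` label so the movement is visible; `ulysses_all_to_all` is a hypothetical single-process stand-in, with the batch and head-dim axes omitted.

```python
def ulysses_all_to_all(local_shards, sp_size):
    """Map per-rank [seq/P, heads] layouts to per-rank [seq, heads/P] layouts.

    local_shards[r][t][h] is the entry rank r holds for its t-th local token
    and head h. The result gives each rank every token, but only its own
    contiguous slice of heads.
    """
    num_heads = len(local_shards[0][0])
    assert num_heads % sp_size == 0
    heads_per_rank = num_heads // sp_size
    out = []
    for dst in range(sp_size):
        lo = dst * heads_per_rank
        rows = []
        # Concatenating source ranks in order restores global token order.
        for src in range(sp_size):
            for token in local_shards[src]:
                rows.append(token[lo:lo + heads_per_rank])
        out.append(rows)
    return out


seq_len, num_heads, sp_size = 8, 4, 2
step = seq_len // sp_size
local_shards = [
    [[(r * step + t, h) for h in range(num_heads)] for t in range(step)]
    for r in range(sp_size)
]
redistributed = ulysses_all_to_all(local_shards, sp_size)
```

After the exchange, rank 1 holds all 8 tokens for heads 2 and 3, so row 5 is token 5 of the full sequence. That is why the backend needs full-sequence metadata at this point rather than the local shard's metadata.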
So before calling the underlying attention implementation, DeepSpeed all-gathers the local position_ids across the sequence-parallel group:
The attention wrapper then passes full_position_ids to the backend. This is important for backends that use position_ids for packed-sequence detection or causal masking.
We use position_ids instead of a full attention mask because the scale is completely different. For a 1M-token sequence, a 4D causal mask is quadratic in sequence length and is not a viable metadata representation. position_ids are only [batch, seq_len], so the all-gather is cheap relative to gathering activations or materializing attention masks.
In other words, the dense attention path does not gather the whole sequence worth of hidden states just to recover metadata. It gathers the small metadata tensor that tells the backend what the full logical sequence is.
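The size gap is easy to quantify. The arithmetic below assumes batch size 1, takes "1M tokens" as 2**20, and assumes a one-byte-per-element boolean mask of shape [batch, 1, seq_len, seq_len] and 8-byte integer position ids; these dtypes are assumptions for the estimate.

```python
def mask_bytes(batch, seq_len, bytes_per_elem=1):
    """Memory for a dense [batch, 1, seq_len, seq_len] attention mask."""
    return batch * 1 * seq_len * seq_len * bytes_per_elem


def position_ids_bytes(batch, seq_len, bytes_per_elem=8):
    """Memory for [batch, seq_len] integer position ids."""
    return batch * seq_len * bytes_per_elem


seq_len = 2 ** 20
mask = mask_bytes(1, seq_len)            # 2**40 bytes = 1 TiB
pos = position_ids_bytes(1, seq_len)     # 2**23 bytes = 8 MiB
ratio = mask // pos                      # 2**17
```

Under these assumptions the mask is 131072 times larger than the position ids, and the gap grows linearly with sequence length.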
Linear attention CP context
Linear attention does not go through the dense attention all-to-all boundary. The patched gated-delta layer receives local hidden states:
Before running the FLA kernels, DeepSpeed builds one CP context for the full logical sequence. The reason is that the FLA kernels need to know how local chunks connect to each other. They need to know which neighboring rank owns the previous tokens, where packed documents start and end, and whether recurrent/convolution state should carry across a boundary or reset.
If position_ids are available, the linear attention path performs the same SP-group all-gather:
It then converts the gathered position_ids into global cumulative sequence lengths. The conversion is based on reset points, where position_id == 0 marks the start of a packed document:
This cu_seqlens representation is what the kernels need. It says that document i covers global token indices from cu_seqlens[i] up to, but not including, cu_seqlens[i + 1].
The kernels can then carry state within each document and reset state between documents. That is the same semantic boundary the unsharded model would see.
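A minimal sketch of that conversion for a single flattened sequence, in plain Python. `position_ids_to_cu_seqlens` is a hypothetical helper; the actual implementation operates on tensors and may treat batches and edge cases differently.

```python
def position_ids_to_cu_seqlens(position_ids):
    """Turn gathered position ids into cumulative document offsets.

    Each position equal to 0 starts a new document; the final entry is the
    total number of tokens.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    if not starts or starts[0] != 0:
        # Assumption: a sequence that does not begin with 0 still starts a document.
        starts = [0] + starts
    return starts + [len(position_ids)]


# Three packed documents (lengths 5, 3, 8) sharded across SP=4 ranks of 4 tokens.
per_rank_position_ids = [
    [0, 1, 2, 3],
    [4, 0, 1, 2],
    [0, 1, 2, 3],
    [4, 5, 6, 7],
]
gathered = [p for rank in per_rank_position_ids for p in rank]
cu_seqlens = position_ids_to_cu_seqlens(gathered)
```

The result is `[0, 5, 8, 16]`. The shard boundaries at 4 and 12 do not appear, and 8 appears only because a real document happens to start there.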
Those cu_seqlens are passed to FLA's CP context builder:
The GPU cu_seqlens are used by kernels. The CPU copy is kept because FLA's CP helper also uses host-side metadata when building the context.
At a high level, the CP context is the object that makes a local shard part of a larger logical sequence. It carries the SP process group, the true sequence/document boundaries, and the convolution kernel size. FLA uses that information to exchange or account for boundary state between ranks.
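To illustrate what such a context lets each rank decide, here is a toy model of it. `ToyCPContext` is not FLA's context object and does not mirror its fields or API; it is a hypothetical dataclass that holds the same three kinds of information and answers two boundary questions, assuming equal-length contiguous shards.

```python
from bisect import bisect_right
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyCPContext:
    rank: int
    world_size: int
    cu_seqlens: tuple        # global document offsets, e.g. (0, 5, 8, 16)
    conv_kernel_size: int

    @property
    def local_span(self):
        """Global [start, end) token range owned by this rank."""
        step = self.cu_seqlens[-1] // self.world_size
        return self.rank * step, (self.rank + 1) * step

    def needs_state_from_previous_rank(self):
        """True when the shard begins in the middle of a document."""
        start, _ = self.local_span
        return start not in self.cu_seqlens

    def conv_halo_tokens(self):
        """Left-context tokens of the same document needed by the causal conv."""
        start, _ = self.local_span
        doc_start = self.cu_seqlens[bisect_right(self.cu_seqlens, start) - 1]
        return min(self.conv_kernel_size - 1, start - doc_start)


contexts = [ToyCPContext(r, 4, (0, 5, 8, 16), 4) for r in range(4)]
carry = [c.needs_state_from_previous_rank() for c in contexts]
halo = [c.conv_halo_tokens() for c in contexts]
```

For these offsets, ranks 1 and 3 start mid-document and need carried state plus a 3-token halo, while ranks 0 and 2 start at real document boundaries and need neither.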
For a simple non-packed path where the layer does not receive position_ids, DeepSpeed falls back to a single contiguous sequence:
That is correct for the non-packed contiguous validation path because there is exactly one logical document and no reset point inside it. It is not enough for packed training. Packed training must preserve position_ids; otherwise the implementation cannot distinguish real document boundaries from sequence-parallel shard boundaries.
The same CP context is then passed to both the causal convolution and the recurrent gated-delta kernel.
Running with Transformers Trainer
Here is a minimal Qwen3.5 Trainer entrypoint:
Launch it with one sequence-parallel group of 8 GPUs:
For the equivalence experiment below, we used gradient_accumulation_steps=sp_size. That matches the token budget between DP=8, SP=1, GAS=1 and DP=1, SP=8, GAS=8.
Numerical validation
We followed the same validation principle as the Ulysses Trainer path: compare canonical token-normalized NLL under matched token budget.
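The two ideas in that sentence can be written down directly. The helpers below are hypothetical and use made-up numbers purely to show the arithmetic; they are not the validation script and the values are not results.

```python
def token_normalized_nll(per_rank):
    """per_rank: (summed_nll, num_valid_tokens) pairs, one per SP rank."""
    total_loss = sum(loss for loss, _ in per_rank)
    total_tokens = sum(tokens for _, tokens in per_rank)
    return total_loss / total_tokens


def mean_of_rank_means(per_rank):
    """Naive aggregation that weights every rank equally, shown for contrast."""
    return sum(loss / tokens for loss, tokens in per_rank) / len(per_rank)


def sequences_per_optimizer_step(dp, gas, micro_batch=1):
    """Sequences contributing to one optimizer step."""
    return dp * gas * micro_batch


# Illustrative values only: two ranks with unequal numbers of valid tokens.
per_rank = [(6.0, 4), (4.0, 1)]
canonical = token_normalized_nll(per_rank)   # 10.0 / 5
naive = mean_of_rank_means(per_rank)         # (1.5 + 4.0) / 2

baseline = sequences_per_optimizer_step(dp=8, gas=1)
ulysses = sequences_per_optimizer_step(dp=1, gas=8)
```

When ranks hold different numbers of valid tokens, the naive mean (2.75 here) diverges from the token-normalized value (2.0), so only the latter is comparable across SP sizes. The two configurations both see 8 sequences per optimizer step.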
E2E Results: Enabling SFT on Qwen3.5-27B with 1M Sequence Length on 4 GPUs