Skip to content

Add Ulysses SP support for FLA gated-delta context parallelism#8114

Open
xylian86 wants to merge 2 commits into
masterfrom
xlian/linear-recurrent-cp
Open

Add Ulysses SP support for FLA gated-delta context parallelism#8114
xylian86 wants to merge 2 commits into
masterfrom
xlian/linear-recurrent-cp

Conversation

@xylian86

@xylian86 xylian86 commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator

Bringing Ulysses Sequence Parallelism to Hybrid Linear Attention Models

Ulysses sequence parallelism is a practical way to train long-context. For dense attention, the boundary is relatively clean: shard the sequence, all-to-all around the attention kernel, and aggregate losses across sequence-parallel ranks.

Hybrid models make this harder. Qwen3.5/3.6/Kimi-style models mix full attention layers with gated-delta linear attention layers. Those linear attention layers are not stateless attention kernels. They depend on causal convolution state, recurrent state, packed-sequence metadata, and kernel-specific assumptions about how chunks are laid out across ranks.

This PR adds DeepSpeed Ulysses support for those gated-delta linear attention layers.

Why this was not a one-line integration

The dense-attention Ulysses path wraps the attention function. That works because the dense attention function receives Q/K/V and returns the context tensor. Linear attention does not expose the same clean boundary.

For gated-delta layers, the causal convolution before the recurrent rule also needs context from neighboring sequence shards. The recurrent kernel itself needs a context-parallel view of the full logical sequence. If either piece treats the local shard as an independent sequence, the output can still be finite, but it is no longer the same model.

The dependency story also matters. Linear attention context parallelism depends on FLA APIs such as:

fla.ops.cp.build_cp_context
fla.ops.cp.FLACPContext
fla.modules.conv.causal_conv1d(..., cp_context=...)
chunk_gated_delta_rule(..., cp_context=...)

Implementation details

The implementation stays inside the existing ulysses_sp.py integration instead of adding a model-specific file.

At registration time, DeepSpeed checks whether the model may contain linear attention. If not, dense-attention-only models keep the existing path. If the model does use linear attention, DeepSpeed imports the Transformers modeling module, finds supported gated-delta layer classes, validates the FLA CP APIs, and patches only those gated-delta forwards.

The most important implementation detail is how we preserve global sequence metadata after sequence sharding. Ulysses splits activations across ranks, but the model is still computing one logical sequence. Any metadata that describes token order, document boundaries, or packed-sequence resets must continue to describe the original full sequence, not the local shard.

Data sharding and position IDs

Ulysses sharding happens in the dataloader adapter, not inside the dataset. The dataset or collator must return full-sequence tensors:

input_ids:    [batch, global_seq_len]
position_ids: [batch, global_seq_len]
labels:       [batch, global_seq_len]

For a normal non-packed sequence, position_ids are simply:

[0, 1, 2, ..., global_seq_len - 1]

For packed sequences, position_ids reset at document boundaries:

[0, 1, 2, 0, 1, 0, 1, 2]

The dataloader adapter first gathers one full batch from each SP rank, then slices every tensor on the sequence dimension. After sharding, each rank holds only its local token span, but the local position_ids still contain the original global/document-relative positions for that span.

That requirement is intentional. If position_ids are generated after sharding, every rank would create a local sequence like [0, 1, ..., local_seq_len - 1]. When those local positions are gathered later, the combined sequence would look like it has artificial document boundaries at every SP boundary.

For example, assume a single non-packed sequence of length 16 and SP=4.

The correct full-sequence metadata is:

global position_ids:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

after sequence sharding:
rank 0: [0,  1,  2,  3]
rank 1: [4,  5,  6,  7]
rank 2: [8,  9, 10, 11]
rank 3: [12, 13, 14, 15]

If each rank generates position_ids locally after sharding, we get:

rank 0: [0, 1, 2, 3]
rank 1: [0, 1, 2, 3]
rank 2: [0, 1, 2, 3]
rank 3: [0, 1, 2, 3]

all-gathered metadata:
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]

That second sequence is not a length-16 sequence anymore. It looks like four packed documents of length 4. This is not a cosmetic metadata bug; it changes the function being computed.

For dense attention, packed-sequence logic may use the reset points to prevent tokens from attending across document boundaries. The incorrect metadata would prevent rank 1 tokens from attending to rank 0 tokens, rank 2 tokens from attending to rank 1 tokens, and so on.

For linear attention, the effect is even more direct. Gated-delta attention is recurrent. A reset at an SP boundary tells the recurrent rule to drop the previous state exactly where it should have carried state forward. The causal convolution has the same issue: tokens at the beginning of rank 1 need convolution context from the end of rank 0. If the SP boundary is treated as a document boundary, the convolution is padded/reset instead of using the true left context.

The opposite mistake is also possible. In packed training, if we lose real position_ids reset points, the recurrent state can leak from one packed document into the next. So the invariant is:

SP boundaries are not document boundaries.
Packed-document boundaries are document boundaries.
position_ids are how we tell the difference.

Dense attention path

For full attention layers, the existing Ulysses wrapper all-to-all redistributes Q/K/V across the sequence and head dimensions. Attention backends still need correct global sequence metadata after that redistribution.

So before calling the underlying attention implementation, DeepSpeed all-gathers the local position_ids across the sequence-parallel group:

position_ids_list = [torch.empty_like(local_position_ids) for _ in range(sp_world_size)]
dist.all_gather(position_ids_list, local_position_ids, group=sp_group)
full_position_ids = torch.cat(position_ids_list, dim=1)

The attention wrapper then passes full_position_ids to the backend. This is important for backends that use position_ids for packed-sequence detection or causal masking.

We use position_ids instead of a full attention mask because the scale is completely different. For a 1M-token sequence, a 4D causal mask is quadratic in sequence length and is not a viable metadata representation. position_ids are only [batch, seq_len], so the all-gather is cheap relative to gathering activations or materializing attention masks.

In other words, the dense attention path does not gather the whole sequence worth of hidden states just to recover metadata. It gathers the small metadata tensor that tells the backend what the full logical sequence is.

Linear attention CP context

Linear attention does not go through the dense attention all-to-all boundary. The patched gated-delta layer receives local hidden states:

hidden_states: [micro_batch, local_seq_len, hidden_size]

Before running the FLA kernels, DeepSpeed builds one CP context for the full logical sequence. The reason is that the FLA kernels need to know how local chunks connect to each other. They need to know which neighboring rank owns the previous tokens, where packed documents start and end, and whether recurrent/convolution state should carry across a boundary or reset.

If position_ids are available, the linear attention path performs the same SP-group all-gather:

position_id_shards = [torch.empty_like(position_ids) for _ in range(sp_world_size)]
dist.all_gather(position_id_shards, position_ids.contiguous(), group=sp_group)
full_position_ids = torch.cat(position_id_shards, dim=1)

It then converts the gathered position_ids into global cumulative sequence lengths. The conversion is based on reset points, where position_id == 0 marks the start of a packed document:

position_ids = [0, 1, 2, 0, 1, 0, 1, 2]
cu_seqlens   = [0, 3, 5, 8]

This cu_seqlens representation is what the kernels need. It says:

document 0 occupies tokens [0, 3)
document 1 occupies tokens [3, 5)
document 2 occupies tokens [5, 8)

The kernels can then carry state within each document and reset state between documents. That is the same semantic boundary the unsharded model would see.

Those cu_seqlens are passed to FLA's CP context builder:

cp_context = build_cp_context(
    cu_seqlens=global_cu_seqlens,
    cu_seqlens_cpu=global_cu_seqlens.cpu(),
    group=sp_group,
    conv1d_kernel_size=conv_kernel_size,
)

The GPU cu_seqlens are used by kernels. The CPU copy is kept because FLA's CP helper also uses host-side metadata when building the context.

At a high level, the CP context is the object that makes a local shard part of a larger logical sequence. It carries the SP process group, the true sequence/document boundaries, and the convolution kernel size. FLA uses that information to exchange or account for boundary state between ranks.

For a simple non-packed path where the layer does not receive position_ids, DeepSpeed falls back to a single contiguous sequence:

cu_seqlens = [0, sp_world_size * local_seq_len]

That is correct for the non-packed contiguous validation path because there is exactly one logical document and no reset point inside it. It is not enough for packed training. Packed training must preserve position_ids, otherwise the implementation cannot distinguish real document boundaries from sequence-parallel shard boundaries.

The same CP context is then passed to both:

causal_conv1d(..., cp_context=cp_context)
chunk_gated_delta_rule(..., cp_context=cp_context)

Running with Transformers Trainer

Here is a minimal Qwen3.5 Trainer entrypoint:

# train_qwen35_ulysses_trainer.py
import os

import torch
from accelerate import ParallelismConfig
from torch.utils.data import Dataset
from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments


class RandomTokenDataset(Dataset):
    def __init__(self, vocab_size, seq_len, num_samples, seed=20260702):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        generator = torch.Generator(device="cpu").manual_seed(self.seed + idx)
        input_ids = torch.randint(
            low=0,
            high=self.vocab_size,
            size=(self.seq_len,),
            generator=generator,
            dtype=torch.long,
        )
        position_ids = torch.arange(self.seq_len, dtype=torch.long)
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "labels": input_ids.clone(),
        }


def collate_fn(features):
    return {
        key: torch.stack([feature[key] for feature in features], dim=0)
        for key in features[0]
    }


def main():
    model_id = os.environ.get("MODEL_ID", "Qwen/Qwen3.5-4B")
    sp_size = int(os.environ.get("SP_SIZE", "8"))
    seq_len = int(os.environ.get("SEQ_LEN", "256"))
    max_steps = int(os.environ.get("MAX_STEPS", "20"))
    attn_impl = os.environ.get("ATTN_IMPL", "flash_attention_2")

    config = AutoConfig.from_pretrained(model_id)
    text_config = config.get_text_config() if hasattr(config, "get_text_config") else config

    dataset = RandomTokenDataset(
        vocab_size=text_config.vocab_size,
        seq_len=seq_len,
        num_samples=max_steps * sp_size,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.bfloat16,
        attn_implementation=attn_impl,
    )
    model.config.use_cache = False

    deepspeed_config = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": "auto",
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": 3},
        "optimizer": {
            "type": "AdamW",
            "params": {"lr": "auto"},
        },
        "sequence_parallel_size": sp_size,
    }

    args = TrainingArguments(
        output_dir="outputs/qwen35-ulysses-sp",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=sp_size,
        max_steps=max_steps,
        learning_rate=1e-5,
        bf16=True,
        logging_steps=1,
        save_steps=0,
        remove_unused_columns=False,
        report_to=[],
        deepspeed=deepspeed_config,
        parallelism_config=ParallelismConfig(
            sp_size=sp_size,
            sp_backend="deepspeed",
        ),
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=collate_fn,
    )
    trainer.train()


if __name__ == "__main__":
    main()

Launch it with one sequence-parallel group of 8 GPUs:

PYTHONPATH=/path/to/DeepSpeed \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node=8 train_qwen35_ulysses_trainer.py

For the equivalence experiment below, we used gradient_accumulation_steps=sp_size. That matches the token budget between DP=8, SP=1, GAS=1 and DP=1, SP=8, GAS=8.

Numerical validation

We followed the same validation principle as the Ulysses Trainer path: compare canonical token-normalized NLL under matched token budget.

Model: Qwen/Qwen3.5-4B
GPUs: 8x NVIDIA B200
dtype: bf16
seq_len: 256
steps: 20

Baseline: DP=8, SP=1, GAS=1
Ulysses:  DP=1, SP=8, GAS=8
Metric:   canonical token-normalized NLL
image
Mean abs diff: 0.00078092
Max abs diff:  0.00190544
step  DP loss      SP loss      abs diff
0     14.35010815  14.35150623  0.00139809
1     14.29790020  14.29677200  0.00112820
2     14.39342308  14.39313984  0.00028324
3     14.41278839  14.41088295  0.00190544
4     14.48497009  14.48612881  0.00115871
5     14.37709522  14.37749767  0.00040245
6     14.34707355  14.34603500  0.00103855
7     14.42868519  14.42852592  0.00015926
8     14.34224033  14.34336567  0.00112534
9     14.33056068  14.33194923  0.00138855
10    14.26563835  14.26433563  0.00130272
11    14.42535973  14.42444229  0.00091743
12    14.36367416  14.36334801  0.00032616
13    14.41588306  14.41629028  0.00040722
14    14.40611553  14.40610790  0.00000763
15    14.36241341  14.36316681  0.00075340
16    14.36795616  14.36826038  0.00030422
17    14.43968582  14.44034576  0.00065994
18    14.41673470  14.41668892  0.00004578
19    14.51559925  14.51650524  0.00090599

E2E Results: Enabling SFT on Qwen3.5-27B with 1M Sequence Length on 4 GPUs

image

@xylian86 xylian86 force-pushed the xlian/linear-recurrent-cp branch 3 times, most recently from 6c7269a to b49edd2 Compare July 2, 2026 07:46
@xylian86 xylian86 marked this pull request as ready for review July 2, 2026 07:47
xylian86 added 2 commits July 2, 2026 07:49
Signed-off-by: Xinyu Lian <lian7@illinois.edu>
Signed-off-by: Xinyu Lian <lian7@illinois.edu>
@xylian86 xylian86 force-pushed the xlian/linear-recurrent-cp branch from 368c122 to 067f377 Compare July 2, 2026 07:50

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 368c1221f5

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +931 to +934
unsupported_kwargs = sorted(set(kwargs) - set(_IGNORED_LINEAR_ATTENTION_FORWARD_KWARGS))
if args or unsupported_kwargs:
raise RuntimeError("Linear attention CP support received unsupported extra forward arguments: "
f"args={len(args)}, kwargs={unsupported_kwargs}")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Accept packed sequence metadata in GDN CP

For packed SFT batches, Transformers' Qwen3.5 GatedDeltaNet forwards cu_seq_lens_q in kwargs and the original implementation passes it into chunk_gated_delta_rule to reset recurrent state at document boundaries. This wrapper rejects that key as unsupported before building cp_context, so packed batches either fail here instead of training or lose the boundary metadata if callers omit it; please consume the packed-sequence kwargs (or derive equivalent cu-seqlens) rather than treating them as unsupported.

Useful? React with 👍 / 👎.

Comment on lines +919 to +922
sp_group, sp_world_size, _ = _get_sequence_parallel_info()
if sp_world_size == 1:
return _call_original_linear_attention_forward(self, hidden_states, cache_params, attention_mask, position_ids,
*args, **kwargs)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Honor eval disablement for linear CP

When register_with_transformers(..., disable_in_eval=True) is used, the full-attention wrapper bypasses sequence-parallel collectives in eval, but patched gated-delta layers still enter this CP path whenever the SP group exists. In HF Trainer eval flows that rely on disable_in_eval because eval batches are not SP-sharded like training batches, these linear layers can still run FLA CP collectives or compute on the wrong distribution; carry the disable flag into this patch and call the original forward when not self.training.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant