From 44fcc76b8de29476c7a931f5732b520958791808 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Thu, 2 Jul 2026 04:18:22 +0000 Subject: [PATCH] Add Ulysses SP support for FLA gated-delta context parallelism Signed-off-by: Xinyu Lian --- .../runtime/sequence_parallel/ulysses_sp.py | 457 +++++++++++++++++- tests/unit/ulysses_alst/test_ulysses_sp_hf.py | 174 +++++++ 2 files changed, 627 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/sequence_parallel/ulysses_sp.py b/deepspeed/runtime/sequence_parallel/ulysses_sp.py index 413921c2090c..b679a12e03b6 100644 --- a/deepspeed/runtime/sequence_parallel/ulysses_sp.py +++ b/deepspeed/runtime/sequence_parallel/ulysses_sp.py @@ -37,10 +37,11 @@ from packaging import version from torch import Tensor from torch.utils.data import DataLoader -from typing import Any -from typing import Tuple +from typing import Any, Optional, Tuple import deepspeed.comm as dist -import importlib.metadata +from importlib import import_module +from importlib import metadata as importlib_metadata +import inspect import math import re import torch @@ -136,7 +137,7 @@ def __init__( self.local_kv_head_count = kv_head_count // self.world_size transformers_version_min = "4.51.3" - transformers_version_have = importlib.metadata.version("transformers") + transformers_version_have = importlib_metadata.version("transformers") if version.parse(transformers_version_have) < version.parse(transformers_version_min): raise ValueError( f"transformers>={transformers_version_min} is required, but you have transformers=={transformers_version_have}" @@ -558,9 +559,457 @@ def uattn_wrapper( # This is what we called "Being John Malkovich". ALL_ATTENTION_FUNCTIONS[core_attn_implementation] = uattn_wrapper + _register_linear_attention_cp(hf_model_config, arch_cfg) + return mpu +# ---------------------------------------------------------------------------- +# Linear attention context-parallel support +# ---------------------------------------------------------------------------- + +# FLA 0.4.2 is the first release with fla.ops.cp and cp_context support in +# causal_conv1d/chunk_gated_delta_rule. +_LINEAR_ATTENTION_CP_MIN_FLA_VERSION = "0.4.2" +_LINEAR_ATTENTION_CP_ORIGINAL_FORWARDS = {} + +_LINEAR_ATTENTION_LAYER_TYPE_MARKERS = ( + "linear_attention", + "linear_attn", + "gated_delta", + "gated_deltanet", + "deltanet", + "kda", + "delta_attention", +) +_LINEAR_ATTENTION_MODEL_TYPE_MARKERS = ( + "qwen3_5", + "qwen3_6", + "kimi", +) +_GATED_DELTA_CLASS_NAME_MARKERS = ("gateddeltanet", "gateddeltarule") +_IGNORED_LINEAR_ATTENTION_FORWARD_KWARGS = ( + "output_attentions", + "output_hidden_states", + "return_dict", + "use_cache", +) + + +def _callable_accepts_keyword(fn, keyword: str) -> bool: + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return True + if keyword in sig.parameters: + return True + return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()) + + +def _gated_delta_state_layout_kwargs(fn): + try: + parameters = inspect.signature(fn).parameters + except (TypeError, ValueError): + return {"transpose_state_layout": True} + if "state_v_first" in parameters: + return {"state_v_first": True} + return {"transpose_state_layout": True} + + +def _get_installed_fla_versions(): + versions = {} + for dist_name in ("flash-linear-attention", "fla-core"): + try: + versions[dist_name] = importlib_metadata.version(dist_name) + except importlib_metadata.PackageNotFoundError: + continue + return versions + + +def _load_linear_attention_cp_ops(): + installed_versions = _get_installed_fla_versions() + parsed_versions = [(package_name, version.parse(fla_version)) + for package_name, fla_version in installed_versions.items()] + + if parsed_versions and all(parsed_version < version.parse(_LINEAR_ATTENTION_CP_MIN_FLA_VERSION) + for _, parsed_version in parsed_versions): + found = ", ".join(f"{name}={installed_versions[name]}" for name, _ in parsed_versions) + raise ImportError("DeepSpeed linear attention CP support requires " + f"flash-linear-attention/fla-core >= {_LINEAR_ATTENTION_CP_MIN_FLA_VERSION}; " + f"found {found}.") + + try: + fla_cp = import_module("fla.ops.cp") + fla_conv = import_module("fla.modules.conv") + except ImportError as exc: + raise ImportError("DeepSpeed linear attention CP support requires an FLA build with " + "`fla.ops.cp` and CP causal convolution support. Install " + f"flash-linear-attention/fla-core >= {_LINEAR_ATTENTION_CP_MIN_FLA_VERSION}.") from exc + + build_cp_context = getattr(fla_cp, "build_cp_context", None) + flacp_context = getattr(fla_cp, "FLACPContext", None) + causal_conv1d = getattr(fla_conv, "causal_conv1d", None) + missing_symbols = [] + if build_cp_context is None: + missing_symbols.append("fla.ops.cp.build_cp_context") + if flacp_context is None: + missing_symbols.append("fla.ops.cp.FLACPContext") + if causal_conv1d is None: + missing_symbols.append("fla.modules.conv.causal_conv1d") + if missing_symbols: + raise ImportError("DeepSpeed linear attention CP support requires missing FLA symbol(s): " + f"{missing_symbols}. Install flash-linear-attention/fla-core >= " + f"{_LINEAR_ATTENTION_CP_MIN_FLA_VERSION}.") + + if not _callable_accepts_keyword(causal_conv1d, "cp_context"): + raise ImportError("Installed FLA `causal_conv1d` does not accept `cp_context`; " + "install a build with context-parallel convolution support.") + + return build_cp_context, causal_conv1d + + +def _get_sequence_parallel_info(): + import deepspeed.runtime.sequence_parallel.parallel_state_sp as mpu + + if getattr(mpu, "_SEQUENCE_PARALLEL_GROUP", None) is None: + return None, 1, 0 + return ( + mpu.get_sequence_parallel_group(), + mpu.get_sequence_parallel_world_size(), + mpu.get_sequence_parallel_rank(), + ) + + +def _iter_config_values(configs, attr_name: str): + for cfg in configs: + value = getattr(cfg, attr_name, None) + if value is None: + continue + if isinstance(value, str): + yield value + elif isinstance(value, (list, tuple, set)): + yield from value + else: + yield value + + +def _model_uses_linear_attention(hf_model_config, arch_cfg) -> bool: + configs = (hf_model_config, arch_cfg) + for layer_type in _iter_config_values(configs, "layer_types"): + layer_type = str(layer_type).lower() + if any(marker in layer_type for marker in _LINEAR_ATTENTION_LAYER_TYPE_MARKERS): + return True + + for cfg in configs: + model_type = str(getattr(cfg, "model_type", "") or "").lower() + if any(marker in model_type for marker in _LINEAR_ATTENTION_MODEL_TYPE_MARKERS): + return True + + return False + + +def _model_type_module_candidates(model_type: str): + model_type = (model_type or "").strip() + if not model_type: + return + seen = set() + model_types = [model_type] + if model_type.endswith("_text"): + model_types.append(model_type[:-len("_text")]) + for candidate in model_types: + if candidate and candidate not in seen: + seen.add(candidate) + yield f"transformers.models.{candidate}.modeling_{candidate}" + + +def _modeling_module_candidates(hf_model_config, arch_cfg): + seen = set() + for cfg in (arch_cfg, hf_model_config): + for module_name in _model_type_module_candidates(str(getattr(cfg, "model_type", "") or "")): + if module_name not in seen: + seen.add(module_name) + yield module_name + + config_module = type(cfg).__module__ + if ".configuration_" in config_module: + module_name = config_module.replace(".configuration_", ".modeling_") + if module_name not in seen: + seen.add(module_name) + yield module_name + + +def _import_first_existing_module(module_names): + import_errors = [] + for module_name in module_names: + try: + return import_module(module_name) + except ImportError as exc: + import_errors.append((module_name, exc)) + if import_errors: + logger.warning_once("Unable to import a Transformers modeling module for linear attention CP support. " + f"Tried: {[name for name, _ in import_errors]}") + return None + + +def _is_gated_delta_class(cls) -> bool: + if not issubclass(cls, torch.nn.Module): + return False + class_name = cls.__name__.lower().replace("_", "") + return any(marker in class_name for marker in _GATED_DELTA_CLASS_NAME_MARKERS) + + +def _iter_gated_delta_classes(module): + for _, cls in inspect.getmembers(module, inspect.isclass): + if _is_gated_delta_class(cls): + yield cls + + +def _install_gated_delta_cp_forward(cls) -> bool: + if cls.forward is _gated_delta_cp_forward: + return False + if cls in _LINEAR_ATTENTION_CP_ORIGINAL_FORWARDS: + raise RuntimeError(f"{cls.__name__}.forward was already patched by linear attention CP support.") + + _LINEAR_ATTENTION_CP_ORIGINAL_FORWARDS[cls] = cls.forward + cls.forward = _gated_delta_cp_forward + logger.info(f"[ulysses_sp] installed gated-delta CP forward for {cls.__module__}.{cls.__name__}") + return True + + +def _register_linear_attention_cp(hf_model_config, arch_cfg) -> int: + if not _model_uses_linear_attention(hf_model_config, arch_cfg): + return 0 + + module = _import_first_existing_module(_modeling_module_candidates(hf_model_config, arch_cfg)) + if module is None: + raise RuntimeError("Ulysses SP detected a model that may contain linear attention layers, " + "but DeepSpeed could not import its Transformers modeling module. " + "Install a Transformers build that exposes the model implementation.") + + classes = list(_iter_gated_delta_classes(module)) + if not classes: + raise RuntimeError("Ulysses SP detected linear attention layers, but DeepSpeed does not yet " + f"have a supported gated-delta CP path for module {module.__name__}.") + + _load_linear_attention_cp_ops() + return sum(int(_install_gated_delta_cp_forward(cls)) for cls in classes) + + +def _position_ids_to_packed_cu_seqlens(position_ids: torch.LongTensor) -> torch.LongTensor: + if position_ids is None: + raise ValueError("position_ids must not be None") + if position_ids.ndim != 2 or position_ids.shape[0] != 1: + raise RuntimeError("Linear attention CP currently requires position_ids with shape [1, seq_len].") + + flat_position_ids = position_ids.reshape(-1) + sequence_starts = (flat_position_ids == 0).nonzero(as_tuple=False).flatten() + if sequence_starts.numel() == 0 or sequence_starts[0].item() != 0: + first_sequence_start = torch.zeros(1, device=position_ids.device, dtype=sequence_starts.dtype) + sequence_starts = torch.cat((first_sequence_start, sequence_starts)) + sequence_end = torch.tensor([flat_position_ids.numel()], device=position_ids.device, dtype=sequence_starts.dtype) + return torch.cat((sequence_starts, sequence_end)).to(dtype=torch.long) + + +def _build_linear_attention_cp_context( + position_ids, + sp_world_size, + sp_group, + conv_kernel_size, + local_seq_len, + device, +): + build_cp_context, _ = _load_linear_attention_cp_ops() + + if position_ids is None: + global_cu_seqlens_cpu = torch.tensor([0, sp_world_size * local_seq_len], dtype=torch.long) + global_cu_seqlens = global_cu_seqlens_cpu.to(device=device, non_blocking=True) + else: + position_id_shards = [torch.empty_like(position_ids) for _ in range(sp_world_size)] + dist.all_gather(position_id_shards, position_ids.contiguous(), group=sp_group) + full_position_ids = torch.cat(position_id_shards, dim=1) + global_cu_seqlens = _position_ids_to_packed_cu_seqlens(full_position_ids) + global_cu_seqlens_cpu = global_cu_seqlens.cpu() + + cp_context = build_cp_context( + cu_seqlens=global_cu_seqlens, + cu_seqlens_cpu=global_cu_seqlens_cpu, + group=sp_group, + conv1d_kernel_size=conv_kernel_size, + ) + return cp_context + + +def _apply_attention_mask_to_hidden_states(hidden_states, attention_mask): + if attention_mask is None: + return hidden_states + if attention_mask.ndim > hidden_states.ndim: + return hidden_states + + mask = attention_mask + if mask.dtype != torch.bool: + mask = mask > 0 + while mask.ndim < hidden_states.ndim: + mask = mask.unsqueeze(-1) + return hidden_states * mask.to(dtype=hidden_states.dtype, device=hidden_states.device) + + +def _call_original_linear_attention_forward( + self, + hidden_states, + cache_params=None, + attention_mask=None, + position_ids=None, + *args, + **kwargs, +): + original = _LINEAR_ATTENTION_CP_ORIGINAL_FORWARDS.get(type(self)) + if original is None: + raise RuntimeError(f"Original forward for {type(self).__name__} is not registered.") + + call_kwargs = {} + for name, value in ( + ("cache_params", cache_params), + ("attention_mask", attention_mask), + ("position_ids", position_ids), + ): + if _callable_accepts_keyword(original, name): + call_kwargs[name] = value + call_kwargs.update(kwargs) + return original(self, hidden_states, *args, **call_kwargs) + + +def _validate_gated_delta_layer(module): + required_attrs = ( + "in_proj_qkv", + "in_proj_z", + "in_proj_b", + "in_proj_a", + "conv1d", + "activation", + "num_v_heads", + "num_k_heads", + "head_k_dim", + "head_v_dim", + "conv_kernel_size", + "A_log", + "dt_bias", + "chunk_gated_delta_rule", + "norm", + "out_proj", + ) + missing = [attr for attr in required_attrs if not hasattr(module, attr)] + if missing: + raise RuntimeError(f"{type(module).__name__} looks like a gated-delta layer but is missing " + f"required attribute(s) for FLA CP: {missing}.") + + if not _callable_accepts_keyword(module.chunk_gated_delta_rule, "cp_context"): + raise ImportError("The installed gated-delta rule kernel does not accept `cp_context`; " + f"install flash-linear-attention/fla-core >= {_LINEAR_ATTENTION_CP_MIN_FLA_VERSION}.") + + +def _gated_delta_cp_forward( + self, + hidden_states: torch.Tensor, + cache_params=None, + attention_mask=None, + position_ids: Optional[torch.LongTensor] = None, + *args, + **kwargs, +): + sp_group, sp_world_size, _ = _get_sequence_parallel_info() + if sp_world_size == 1: + return _call_original_linear_attention_forward(self, hidden_states, cache_params, attention_mask, position_ids, + *args, **kwargs) + + if hidden_states.shape[1] == 1 and cache_params is not None \ + and cache_params.has_previous_state(self.layer_idx): + return _call_original_linear_attention_forward(self, hidden_states, cache_params, attention_mask, position_ids, + *args, **kwargs) + if cache_params is not None: + raise RuntimeError("Linear attention CP support under Ulysses SP is training/prefill-only; " + "pass cache_params=None or disable sequence parallelism for this forward.") + unsupported_kwargs = sorted(set(kwargs) - set(_IGNORED_LINEAR_ATTENTION_FORWARD_KWARGS)) + if args or unsupported_kwargs: + raise RuntimeError("Linear attention CP support received unsupported extra forward arguments: " + f"args={len(args)}, kwargs={unsupported_kwargs}") + if hidden_states.ndim != 3: + raise RuntimeError(f"Linear attention CP expects hidden_states with shape [B, S/P, H], " + f"got {tuple(hidden_states.shape)}.") + if hidden_states.shape[0] != 1: + raise RuntimeError("FLA linear attention CP currently requires micro_batch_size == 1.") + + _validate_gated_delta_layer(self) + _, causal_conv1d = _load_linear_attention_cp_ops() + + batch_size, local_seq_len, _ = hidden_states.shape + device = hidden_states.device + cp_context = _build_linear_attention_cp_context( + position_ids=position_ids, + sp_world_size=sp_world_size, + sp_group=sp_group, + conv_kernel_size=self.conv_kernel_size, + local_seq_len=local_seq_len, + device=device, + ) + + hidden_states = _apply_attention_mask_to_hidden_states(hidden_states, attention_mask) + mixed_qkv = self.in_proj_qkv(hidden_states) + z_gate = self.in_proj_z(hidden_states) + beta_logits = self.in_proj_b(hidden_states) + gate_logits = self.in_proj_a(hidden_states) + + conv_result = causal_conv1d( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1).contiguous(), + bias=self.conv1d.bias, + activation=self.activation, + cp_context=cp_context, + ) + qkv = conv_result[0] if isinstance(conv_result, tuple) else conv_result + + num_v_heads = self.num_v_heads + num_k_heads = self.num_k_heads + head_k_dim = self.head_k_dim + head_v_dim = self.head_v_dim + key_dim = num_k_heads * head_k_dim + value_dim = num_v_heads * head_v_dim + expected_qkv_dim = 2 * key_dim + value_dim + if qkv.shape[-1] != expected_qkv_dim: + raise RuntimeError(f"Unexpected gated-delta qkv projection dimension {qkv.shape[-1]}; " + f"expected {expected_qkv_dim}.") + + query = qkv[..., :key_dim].reshape(batch_size, local_seq_len, num_k_heads, head_k_dim) + key = qkv[..., key_dim:2 * key_dim].reshape(batch_size, local_seq_len, num_k_heads, head_k_dim) + value = qkv[..., 2 * key_dim:].reshape(batch_size, local_seq_len, num_v_heads, head_v_dim) + + if num_v_heads % num_k_heads != 0: + raise RuntimeError(f"num_v_heads={num_v_heads} must be a multiple of num_k_heads={num_k_heads}.") + value_heads_per_key_head = num_v_heads // num_k_heads + if value_heads_per_key_head > 1: + query = query.repeat_interleave(value_heads_per_key_head, dim=2) + key = key.repeat_interleave(value_heads_per_key_head, dim=2) + + beta = beta_logits.sigmoid() + gate = -self.A_log.float().exp() * torch.nn.functional.softplus(gate_logits.float() + self.dt_bias) + + core_attn_out, _ = self.chunk_gated_delta_rule( + query, + key, + value, + g=gate, + beta=beta, + cp_context=cp_context, + use_qk_l2norm_in_kernel=True, + **_gated_delta_state_layout_kwargs(self.chunk_gated_delta_rule), + ) + + core_attn_out = core_attn_out.reshape(-1, head_v_dim) + z_gate = z_gate.reshape(-1, head_v_dim) + core_attn_out = self.norm(core_attn_out, z_gate) + core_attn_out = core_attn_out.reshape(batch_size, local_seq_len, num_v_heads * head_v_dim) + return self.out_proj(core_attn_out) + + class UlyssesSPDataLoaderAdapter: def __init__( diff --git a/tests/unit/ulysses_alst/test_ulysses_sp_hf.py b/tests/unit/ulysses_alst/test_ulysses_sp_hf.py index 550233d9239e..098c553209f6 100644 --- a/tests/unit/ulysses_alst/test_ulysses_sp_hf.py +++ b/tests/unit/ulysses_alst/test_ulysses_sp_hf.py @@ -6,6 +6,7 @@ UlyssesPlus: UlyssesSPHF tests """ +import deepspeed.runtime.sequence_parallel.ulysses_sp as ulysses_sp from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPAttentionHF, UlyssesSPDataLoaderAdapter from deepspeed.runtime.utils import move_to_device from deepspeed.utils import groups @@ -17,7 +18,9 @@ import deepspeed import deepspeed.comm as dist import pytest +import sys import torch +import types def get_grad(param, zero_stage): @@ -29,6 +32,177 @@ def get_grad(param, zero_stage): # return safe_get_full_grad(param) +class TestLinearAttentionCPHelpers: + + def test_position_ids_to_packed_cu_seqlens_single_sequence(self): + position_ids = torch.tensor([[0, 1, 2, 3]]) + + cu_seqlens = ulysses_sp._position_ids_to_packed_cu_seqlens(position_ids) + + torch_assert_equal(cu_seqlens, torch.tensor([0, 4], dtype=torch.long)) + + def test_position_ids_to_packed_cu_seqlens_packed_sequence(self): + position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]) + + cu_seqlens = ulysses_sp._position_ids_to_packed_cu_seqlens(position_ids) + + torch_assert_equal(cu_seqlens, torch.tensor([0, 3, 5, 8], dtype=torch.long)) + + def test_modeling_module_candidates_strip_text_suffix(self): + cfg = types.SimpleNamespace(model_type="qwen3_5_text") + + candidates = list(ulysses_sp._modeling_module_candidates(cfg, cfg)) + + assert "transformers.models.qwen3_5_text.modeling_qwen3_5_text" in candidates + assert "transformers.models.qwen3_5.modeling_qwen3_5" in candidates + + def test_linear_attention_cp_noops_for_non_linear_config(self, monkeypatch): + + def fail_if_called(_name): + raise AssertionError("FLA package version should not be probed for non-linear configs") + + monkeypatch.setattr(ulysses_sp.importlib_metadata, "version", fail_if_called) + cfg = types.SimpleNamespace(model_type="llama", layer_types=["full_attention"]) + + assert ulysses_sp._register_linear_attention_cp(cfg, cfg) == 0 + + def test_linear_attention_cp_version_gate(self, monkeypatch): + + def fake_version(_name): + return "0.4.1" + + monkeypatch.setattr(ulysses_sp.importlib_metadata, "version", fake_version) + + with pytest.raises(ImportError, match=">= 0.4.2"): + ulysses_sp._load_linear_attention_cp_ops() + + def test_gated_delta_state_layout_kwargs_match_fla_version_signatures(self): + + def old_chunk_gated_delta_rule(transpose_state_layout=False, **kwargs): + return transpose_state_layout, kwargs + + def new_chunk_gated_delta_rule(state_v_first=False, **kwargs): + return state_v_first, kwargs + + assert ulysses_sp._gated_delta_state_layout_kwargs(old_chunk_gated_delta_rule) == { + "transpose_state_layout": True + } + assert ulysses_sp._gated_delta_state_layout_kwargs(new_chunk_gated_delta_rule) == {"state_v_first": True} + + def test_linear_attention_cp_ignores_transformers_forward_flags(self, monkeypatch): + + class FakeNorm(torch.nn.Module): + + def forward(self, hidden_states, gate): + return hidden_states + + class FakeGatedDeltaNet(torch.nn.Module): + + def __init__(self): + super().__init__() + self.in_proj_qkv = torch.nn.Linear(4, 12, bias=False) + self.in_proj_z = torch.nn.Linear(4, 4, bias=False) + self.in_proj_b = torch.nn.Linear(4, 1, bias=False) + self.in_proj_a = torch.nn.Linear(4, 1, bias=False) + self.conv1d = torch.nn.Conv1d(12, 12, kernel_size=1, groups=12) + self.activation = "silu" + self.num_v_heads = 1 + self.num_k_heads = 1 + self.head_k_dim = 4 + self.head_v_dim = 4 + self.conv_kernel_size = 1 + self.A_log = torch.nn.Parameter(torch.zeros(1)) + self.dt_bias = torch.nn.Parameter(torch.zeros(1)) + self.norm = FakeNorm() + self.out_proj = torch.nn.Linear(4, 4, bias=False) + + def chunk_gated_delta_rule(self, query, key, value, g=None, beta=None, cp_context=None, **kwargs): + return value, None + + def fake_causal_conv1d(x, weight=None, bias=None, activation=None, cp_context=None): + return x + + monkeypatch.setattr(ulysses_sp, "_get_sequence_parallel_info", lambda: (None, 2, 0)) + monkeypatch.setattr(ulysses_sp, "_load_linear_attention_cp_ops", lambda: (None, fake_causal_conv1d)) + monkeypatch.setattr(ulysses_sp, "_build_linear_attention_cp_context", lambda **kwargs: object()) + + layer = FakeGatedDeltaNet() + hidden_states = torch.randn(1, 2, 4) + + output = ulysses_sp._gated_delta_cp_forward( + layer, + hidden_states, + use_cache=False, + output_hidden_states=True, + output_attentions=False, + return_dict=True, + ) + + assert output.shape == hidden_states.shape + + def test_linear_attention_cp_patches_gdn_like_class(self, monkeypatch): + modeling_module_name = "transformers.models.fake_linear.modeling_fake_linear" + modeling_module = types.ModuleType(modeling_module_name) + + class FakeGatedDeltaNet(torch.nn.Module): + + def forward(self, hidden_states, cache_params=None, attention_mask=None, position_ids=None): + return hidden_states + + modeling_module.FakeGatedDeltaNet = FakeGatedDeltaNet + + fla_module = types.ModuleType("fla") + fla_ops_module = types.ModuleType("fla.ops") + fla_cp_module = types.ModuleType("fla.ops.cp") + fla_modules_module = types.ModuleType("fla.modules") + fla_conv_module = types.ModuleType("fla.modules.conv") + + class FakeFLACPContext: + pass + + def fake_build_cp_context(*args, **kwargs): + return FakeFLACPContext() + + def fake_causal_conv1d(x, weight=None, bias=None, activation=None, cp_context=None): + return x, None + + fla_cp_module.FLACPContext = FakeFLACPContext + fla_cp_module.build_cp_context = fake_build_cp_context + fla_conv_module.causal_conv1d = fake_causal_conv1d + + fla_module.ops = fla_ops_module + fla_ops_module.cp = fla_cp_module + fla_module.modules = fla_modules_module + fla_modules_module.conv = fla_conv_module + + for name, module in ( + ("fla", fla_module), + ("fla.ops", fla_ops_module), + ("fla.ops.cp", fla_cp_module), + ("fla.modules", fla_modules_module), + ("fla.modules.conv", fla_conv_module), + (modeling_module_name, modeling_module), + ): + monkeypatch.setitem(sys.modules, name, module) + + monkeypatch.setattr(ulysses_sp.importlib_metadata, "version", lambda _name: "0.5.0") + + cfg = types.SimpleNamespace(model_type="fake_linear", layer_types=["linear_attention"]) + original_forward = FakeGatedDeltaNet.forward + + try: + installed = ulysses_sp._register_linear_attention_cp(cfg, cfg) + installed_again = ulysses_sp._register_linear_attention_cp(cfg, cfg) + + assert installed == 1 + assert installed_again == 0 + assert FakeGatedDeltaNet.forward is ulysses_sp._gated_delta_cp_forward + assert ulysses_sp._LINEAR_ATTENTION_CP_ORIGINAL_FORWARDS[FakeGatedDeltaNet] is original_forward + finally: + FakeGatedDeltaNet.forward = original_forward + ulysses_sp._LINEAR_ATTENTION_CP_ORIGINAL_FORWARDS.pop(FakeGatedDeltaNet, None) + + @pytest.mark.parametrize("zero_stage", [2, 3]) class TestUlyssesSPHF(DistributedTest): world_size = 2