Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions diffsynth_engine/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,29 @@ def get_ranks(self, token):
return ranks


# WORLD
def is_world_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_WORLD`` group is not ``None``."""
    world = _WORLD
    return world is not None


def get_world_group() -> GroupCoordinator:
    """Return the world group stored in ``_WORLD``.

    Raises:
        AssertionError: If ``_WORLD`` is still ``None``.
    """
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


def get_global_world_size():
    """Return the ``world_size`` attribute of the world group."""
    world = get_world_group()
    return world.world_size


def get_global_rank():
    """Return the ``rank_in_group`` attribute of the world group."""
    world = get_world_group()
    return world.rank_in_group


# TP
def is_tp_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_TP`` group is not ``None``."""
    tp_group = _TP
    return tp_group is not None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group stored in ``_TP``.

    Raises:
        AssertionError: If ``_TP`` is still ``None``.
    """
    tp_group = _TP
    assert tp_group is not None, "tensor model parallel group is not initialized"
    return tp_group
Expand All @@ -236,6 +253,10 @@ def get_tensor_model_parallel_rank():


# SP
def is_sp_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_SP`` group is not ``None``."""
    sp_group = _SP
    return sp_group is not None


def get_sp_group() -> SequenceParallelGroupCoordinator:
    """Return the sequence parallel group stored in ``_SP``.

    Raises:
        AssertionError: If ``_SP`` is still ``None``.
    """
    # The message previously named the pipeline model parallel group
    # (copy-paste slip), which pointed users at the wrong initializer.
    assert _SP is not None, "sequence parallel group is not initialized"
    return _SP
Expand Down Expand Up @@ -268,6 +289,10 @@ def get_ring_parallel_rank():


# PP
def is_pp_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_PP`` group is not ``None``."""
    pp_group = _PP
    return pp_group is not None


def get_pp_group() -> PipelineGroupCoordinator:
    """Return the pipeline model parallel group stored in ``_PP``.

    Raises:
        AssertionError: If ``_PP`` is still ``None``.
    """
    pp_group = _PP
    assert pp_group is not None, "pipeline model parallel group is not initialized"
    return pp_group
Expand All @@ -294,6 +319,10 @@ def is_pipeline_last_stage():


# CFG
def is_cfg_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_CFG`` group is not ``None``."""
    cfg_group = _CFG
    return cfg_group is not None


def get_cfg_group() -> GroupCoordinator:
    """Return the classifier-free-guidance parallel group stored in ``_CFG``.

    Raises:
        AssertionError: If ``_CFG`` is still ``None``.
    """
    cfg_group = _CFG
    assert cfg_group is not None, "classifier_free_guidance parallel group is not initialized"
    return cfg_group
Expand All @@ -310,6 +339,10 @@ def get_classifier_free_guidance_rank():


# DP
def is_dp_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_DP`` group is not ``None``."""
    dp_group = _DP
    return dp_group is not None


def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group stored in ``_DP``.

    Raises:
        AssertionError: If ``_DP`` is still ``None``.
    """
    # The message previously named the pipeline model parallel group
    # (copy-paste slip), which pointed users at the wrong initializer.
    assert _DP is not None, "data parallel group is not initialized"
    return _DP
Expand Down Expand Up @@ -346,6 +379,10 @@ def get_dit_world_size():


# VAE
def is_vae_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_VAE`` group is not ``None``."""
    vae_group = _VAE
    return vae_group is not None


def get_vae_parallel_group() -> GroupCoordinator:
    """Return the VAE parallel group stored in ``_VAE``.

    Raises:
        AssertionError: If ``_VAE`` is still ``None``.
    """
    vae_group = _VAE
    assert vae_group is not None, "VAE parallel group is not initialized"
    return vae_group
Expand Down Expand Up @@ -491,6 +528,10 @@ def init_dit_group(
_DIT = torch.distributed.new_group(ranks=list(range(dit_parallel_size)), backend=backend)


def is_dit_group_initialized() -> bool:
    """Return ``True`` when the module-level ``_DIT`` group is not ``None``."""
    dit_group = _DIT
    return dit_group is not None


def get_dit_group():
    """Return the DIT group stored in ``_DIT``.

    Raises:
        AssertionError: If ``_DIT`` is still ``None``.
    """
    dit_group = _DIT
    assert dit_group is not None, "DIT group is not initialized"
    return dit_group
Expand Down
4 changes: 4 additions & 0 deletions diffsynth_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def shutdown(self):

self.workers = None
self.conns = None

if self.pipeline is not None:
del self.pipeline
self.pipeline = None

def start_profile(self, path: str = ".", profile_rank0_only: bool = True):
if self.workers is not None:
Expand Down
6 changes: 3 additions & 3 deletions diffsynth_engine/layers/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist
import torch.nn as nn

from diffsynth_engine.distributed.comm import SeqAllToAll4D
from diffsynth_engine.distributed.parallel_state import (
get_ring_parallel_world_size,
get_sp_group,
get_ulysses_parallel_world_size,
is_sp_group_initialized,
)
from diffsynth_engine.forward_context import ForwardContext, get_forward_context
from diffsynth_engine.layers.attention.backends.abstract import AttentionType
Expand Down Expand Up @@ -139,8 +139,8 @@ def forward(
attn_kwargs = {"attn_metadata": attn_metadata}
attn_kwargs.update(kwargs)

ulysses_parallel_world_size = get_ulysses_parallel_world_size() if dist.is_initialized() else 1
ring_parallel_world_size = get_ring_parallel_world_size() if dist.is_initialized() else 1
ulysses_parallel_world_size = get_ulysses_parallel_world_size() if is_sp_group_initialized() else 1
ring_parallel_world_size = get_ring_parallel_world_size() if is_sp_group_initialized() else 1

if ulysses_parallel_world_size > 1:
q = SeqAllToAll4D.apply(get_sp_group().ulysses_group, q, self.scatter_idx, self.gather_idx)
Expand Down
7 changes: 7 additions & 0 deletions diffsynth_engine/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
class DiffusionModel(nn.Module, ConfigMixin):
config_name = CONFIG_NAME

@property
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@molepi40 这个改动好像没有作用,删了吧

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于 diffusers 提供的实现无法跑通,现在 qwen image layered 的 pipeline 用 from diffsynth_engine.models.qwen_image.autoencoder_kl_qwenimage import AutoencoderKLQwenImage 替换了原有的 from diffusers.models import AutoencoderKLQwenImage,而 engine 中的实现没有提供 dtype 属性,需要自己实现,或者继承 diffusers 的 ModelMixin

def dtype(self) -> torch.dtype:
    """Return the dtype of the first parameter, or of the first buffer if there are no parameters.

    Returns:
        The ``dtype`` of the first tensor yielded by ``self.parameters()``;
        when the module has no parameters, the ``dtype`` of the first tensor
        yielded by ``self.buffers()``.

    Raises:
        RuntimeError: If the module has neither parameters nor buffers.
    """
    first = next(self.parameters(), None)
    if first is None:
        # Parameter-free modules (e.g. ones holding only registered buffers)
        # still have a well-defined dtype; only consult buffers in that case
        # so results for ordinary modules are unchanged.
        first = next(self.buffers(), None)
    if first is None:
        raise RuntimeError(f"{type(self).__name__} has no parameters, cannot determine dtype")
    return first.dtype

@classmethod
def from_pretrained(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
# - Paper: https://huggingface.co/papers/2503.20314

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import register_to_config
Expand Down Expand Up @@ -1167,8 +1166,8 @@ def parallel_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:

self.clear_cache()

dist.all_reduce(values, group=group)
dist.all_reduce(weight, group=group)
group.all_reduce(values)
group.all_reduce(weight)

enc = values / weight
return enc
Expand Down Expand Up @@ -1247,8 +1246,8 @@ def parallel_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> De

self.clear_cache()

dist.all_reduce(values, group=group)
dist.all_reduce(weight, group=group)
group.all_reduce(values)
group.all_reduce(weight)

dec = values / weight
dec = torch.clamp(dec, min=-1.0, max=1.0)
Expand Down
131 changes: 129 additions & 2 deletions diffsynth_engine/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from typing import Type

import torch
import torch.distributed as dist
import torch.nn as nn
from accelerate import init_empty_weights
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from tqdm import tqdm

from diffsynth_engine.configs import PipelineConfig
from diffsynth_engine.distributed.parallel_state import get_global_rank, is_world_group_initialized
from diffsynth_engine.forward_context import set_forward_context
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.load_utils import fix_state_dict_key, load_model_weights

logger = logging.get_logger(__name__)


class Pipeline:
Expand All @@ -17,6 +28,122 @@ def from_pretrained(cls, model_path_or_config: str | PipelineConfig):
def __call__(self, *args, **kwargs):
raise NotImplementedError()

@staticmethod
def init_transformer(model_cls: Type[nn.Module], pipeline_config: PipelineConfig, empty_weights: bool = False):
    """Build the transformer from its config and, unless ``empty_weights`` is set, load its weights.

    The model is constructed inside ``init_empty_weights()`` (and the forward
    context for ``pipeline_config.attn_type``) from the config found in the
    ``transformer`` subfolder of ``pipeline_config.model_path``.

    Args:
        model_cls: Class exposing ``load_config`` and ``from_config``.
        pipeline_config: Source of ``model_path``, ``attn_type``, ``use_fsdp``,
            ``device`` and ``model_dtype``.
        empty_weights: When true, return the freshly constructed model without
            sharding it or loading any weights.

    Returns:
        The constructed model.
    """
    # FSDP is used only when requested AND the world group has been set up.
    sharded = pipeline_config.use_fsdp and is_world_group_initialized()

    with set_forward_context(attn_type=pipeline_config.attn_type), init_empty_weights():
        model_config = model_cls.load_config(
            pipeline_config.model_path,
            subfolder="transformer",
            local_files_only=True,
        )
        model = model_cls.from_config(model_config)

    if empty_weights:
        return model

    if sharded:
        # Wrap every transformer block first, then the model as a whole.
        for transformer_block in model.transformer_blocks:
            fully_shard(transformer_block)
        fully_shard(model)

    # With FSDP the weights are loaded on CPU and not broadcast by the loader;
    # set_model_state_dict below is asked to broadcast from rank 0 instead.
    load_device = "cpu" if sharded else pipeline_config.device
    weights = load_model_weights(
        pipeline_config.model_path,
        subfolder="transformer",
        device=load_device,
        dtype=pipeline_config.model_dtype,
        broadcast_from_rank0=not sharded,
    )

    if not sharded:
        model.load_state_dict(weights, strict=True, assign=True)
        model.to(device=pipeline_config.device)
    else:
        sharded_load_options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
        set_model_state_dict(model, weights, options=sharded_load_options)

    del weights
    return model

@staticmethod
def init_text_encoder(
model_cls: Type[nn.Module],
pipeline_config: PipelineConfig,
key_mapping: dict = None,
empty_weights: bool = False,
):
use_fsdp = pipeline_config.use_fsdp and is_world_group_initialized()

with init_empty_weights():
config = model_cls.config_class.from_pretrained(
pipeline_config.model_path,
subfolder="text_encoder",
local_files_only=True,
)
model = model_cls(config)

if empty_weights:
return model

if use_fsdp:
for layer in model.model.language_model.layers:
fully_shard(layer)
Comment on lines +94 to +95
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The attribute path model.model.language_model.layers appears to be incorrect for Qwen2_5_VLForConditionalGeneration (which is the model class used in the QwenImage pipelines). In the standard transformers implementation for Qwen2.5-VL, the layers are located at model.model.layers. Using the current path will result in an AttributeError when FSDP is enabled.

Suggested change
for layer in model.model.language_model.layers:
fully_shard(layer)
if use_fsdp:
for layer in model.model.layers:
fully_shard(layer)
fully_shard(model)

fully_shard(model)

state_dict = load_model_weights(
pipeline_config.model_path,
subfolder="text_encoder",
device="cpu" if use_fsdp else pipeline_config.device,
dtype=pipeline_config.text_encoder_dtype,
broadcast_from_rank0=not use_fsdp,
)

if key_mapping:
state_dict = fix_state_dict_key(state_dict, key_mapping)

if use_fsdp:
set_model_state_dict(
model,
state_dict,
options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
)
else:
model.load_state_dict(state_dict, strict=True, assign=True)
model.to(device=pipeline_config.device)

del state_dict
return model

@staticmethod
def init_vae(model_cls: Type[nn.Module], pipeline_config: PipelineConfig, empty_weights: bool = False):
    """Build the VAE from its config and, unless ``empty_weights`` is set, load its weights.

    The model is constructed inside ``init_empty_weights()`` from the config
    found in the ``vae`` subfolder of ``pipeline_config.model_path``.

    Args:
        model_cls: Class exposing ``load_config`` and ``from_config``.
        pipeline_config: Source of ``model_path``, ``device`` and ``vae_dtype``.
        empty_weights: When true, return the freshly constructed model without
            loading any weights.

    Returns:
        The constructed model.
    """
    with init_empty_weights():
        vae_config = model_cls.load_config(
            pipeline_config.model_path,
            subfolder="vae",
            local_files_only=True,
        )
        model = model_cls.from_config(vae_config)

    if not empty_weights:
        weights = load_model_weights(
            pipeline_config.model_path,
            subfolder="vae",
            device=pipeline_config.device,
            dtype=pipeline_config.vae_dtype,
        )
        # assign=True makes the module adopt the loaded tensors in place of
        # the placeholders created under init_empty_weights().
        model.load_state_dict(weights, strict=True, assign=True)
        model.to(device=pipeline_config.device)
        del weights

    return model

@torch.compiler.disable
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
Expand All @@ -28,7 +155,7 @@ def progress_bar(self, iterable=None, total=None):

progress_bar_config = dict(self._progress_bar_config)
if "disable" not in progress_bar_config:
is_rank_zero = not dist.is_initialized() or dist.get_rank() == 0
is_rank_zero = not is_world_group_initialized() or get_global_rank() == 0
progress_bar_config["disable"] = not is_rank_zero

if iterable is not None:
Expand Down
Loading