From 2354fda9c0d66ae1e1606463a767b7b8173ff73a Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:44:12 +0000 Subject: [PATCH 01/38] init --- ...convert_z_image_controlnet_to_diffusers.py | 103 +++ src/diffusers/models/controlnets/__init__.py | 1 + .../models/controlnets/controlnet_z_image.py | 528 ++++++++++++++ .../transformers/transformer_z_image.py | 11 +- .../z_image/pipeline_z_image_controlnet.py | 674 ++++++++++++++++++ 5 files changed, 1315 insertions(+), 2 deletions(-) create mode 100644 scripts/convert_z_image_controlnet_to_diffusers.py create mode 100644 src/diffusers/models/controlnets/controlnet_z_image.py create mode 100644 src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py new file mode 100644 index 000000000000..c4b96cda02af --- /dev/null +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -0,0 +1,103 @@ +import argparse +from contextlib import nullcontext + +import torch +import safetensors.torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers.utils.import_utils import is_accelerate_available +from diffusers.models import ZImageTransformer2DModel +from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel + +""" +python scripts/convert_z_image_controlnet_to_diffusers.py \ +--original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \ +--original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \ +--filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" +--output_path "z-image-controlnet-hf/" +""" + + +CTX = init_empty_weights if is_accelerate_available else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str) +parser.add_argument("--original_controlnet_repo_id", default=None, type=str) +parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--output_path", type=str) + +args = parser.parse_args() + + +def load_original_checkpoint(args): + if args.original_controlnet_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_controlnet_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_controlnet_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + +def load_z_image(args): + model = ZImageTransformer2DModel.from_pretrained(args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16) + return model.state_dict(), model.config + +def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): + converted_state_dict = {} + + converted_state_dict.update(original_state_dict) + + to_copy = {"all_x_embedder.", "noise_refiner.", "context_refiner.", "t_embedder.", "cap_embedder.", "x_pad_token", "cap_pad_token"} + + for key in z_image.keys(): + for copy_key in to_copy: + if key.startswith(copy_key): + converted_state_dict[key] = z_image[key] + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + z_image, config = load_z_image(args) + + control_in_dim = 16 + control_layers_places = [0, 5, 10, 15, 20, 25] + + 
converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt) + + for key, tensor in converted_controlnet_state_dict.items(): + print(f"{key} - {tensor.dtype}") + + controlnet = ZImageControlNetModel( + all_patch_size=config["all_patch_size"], + all_f_patch_size=config["all_f_patch_size"], + in_channels=config["in_channels"], + dim=config["dim"], + n_layers=config["n_layers"], + n_refiner_layers=config["n_refiner_layers"], + n_heads=config["n_heads"], + n_kv_heads=config["n_kv_heads"], + norm_eps=config["norm_eps"], + qk_norm=config["qk_norm"], + cap_feat_dim=config["cap_feat_dim"], + rope_theta=config["rope_theta"], + t_scale=config["t_scale"], + axes_dims=config["axes_dims"], + axes_lens=config["axes_lens"], + control_layers_places=control_layers_places, + control_in_dim=control_in_dim, + ) + missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict) + print(f"{missing=}") + print(f"{unexpected=}") + print("Saving Z-Image ControlNet in Diffusers format") + controlnet.save_pretrained(args.output_path, max_shard_size="5GB") + + +if __name__ == "__main__": + main(args) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index 7ce352879daa..fee7f231e899 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -19,6 +19,7 @@ ) from .controlnet_union import ControlNetUnionModel from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + from .controlnet_z_image import ZImageControlNetModel from .multicontrolnet import MultiControlNetModel from .multicontrolnet_union import MultiControlNetUnionModel diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py new file mode 100644 index 000000000000..d6cede86812d --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -0,0 +1,528 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...models.normalization import RMSNorm +from ..controlnets.controlnet import zero_module +from ..modeling_utils import ModelMixin +from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM + + +class ZImageControlTransformerBlock(ZImageTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0 + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + self.block_id = block_id + if block_id == 0: + self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) + self.after_proj = zero_module(nn.Linear(self.dim, self.dim)) + + def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + +class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + control_layers_places: List[int]=None, + control_in_dim=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + self.n_layers = n_layers + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear(cap_feat_dim, dim, bias=True), + ) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + ## Original Control layers + + self.control_layers_places = control_layers_places + 
self.control_in_dim = control_in_dim + + assert 0 in self.control_layers_places + + # control blocks + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + block_id=i + ) + for i in self.control_layers_places + ] + ) + + # control patch embeddings + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + cap_padding_len: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_image_size, + all_image_pos_ids, + all_image_pad_mask, + ) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = 
self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + all_cap_pad_mask.append( + torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) + ) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + all_image_pos_ids.append(image_padded_pos_ids) + # pad mask + all_image_pad_mask.append( + torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + ) + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + cap_feats: List[torch.Tensor], + control_context: List[torch.Tensor], + t=None, + patch_size=2, + f_patch_size=1, + conditioning_scale: float = 1.0, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = 
self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + ## Original forward_control + + # embeddings + bsz = len(control_context) + device = control_context[0].device + ( + control_context, + x_size, + x_pos_ids, + x_inner_pad_mask, + ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) + + # control_context embed & refine + x_item_seqlens = [len(_) for _ in control_context] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + # Match t_embedder output dtype to control_context for layerwise casting compatibility + adaln_input = t.type_as(control_context) + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in 
self.control_noise_refiner: + control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # unified + cap_item_seqlens = [len(_) for _ in cap_feats] + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + c = control_context_unified + + new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + c = self._gradient_checkpointing_func(layer, c, **new_kwargs) + else: + c = layer(c, **new_kwargs) + + hints = torch.unbind(c)[:-1] * conditioning_scale + controlnet_block_samples = {} + for layer_idx in range(self.n_layers): + if layer_idx in self.control_layers_places: + hints_idx = self.control_layers_places.index(layer_idx) + controlnet_block_samples[layer_idx] = hints[hints_idx] + + return controlnet_block_samples diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5c401b9d202b..2d332217d897 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,6 +538,7 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, + controlnet_block_samples: Optional[dict[int, torch.Tensor]]=None, return_dict: bool = True, ): assert patch_size in self.all_patch_size @@ -635,13 +636,19 @@ def forward( unified_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] else: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) unified = list(unified.unbind(dim=0)) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py new file mode 100644 index 000000000000..609b141be796 --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -0,0 +1,674 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImagePipeline + + >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + 
num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. 
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + prompt_embeds_model_input, + control_image, + timestep_model_input, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, 
return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) From 1e2009de435516caf7b6e67ab215f8f6299c375f Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:48:38 +0000 Subject: [PATCH 02/38] passed transformer --- .../models/controlnets/controlnet_z_image.py | 98 +++---------------- .../z_image/pipeline_z_image_controlnet.py | 1 + 2 files changed, 15 insertions(+), 84 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index d6cede86812d..6fe9d38ce3d1 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -23,7 +23,7 @@ from ...models.normalization import RMSNorm from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin -from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM +from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM class ZImageControlTransformerBlock(ZImageTransformerBlock): @@ -66,87 +66,16 @@ def __init__( self, all_patch_size=(2,), all_f_patch_size=(1,), - in_channels=16, dim=3840, - n_layers=30, n_refiner_layers=2, n_heads=30, n_kv_heads=30, norm_eps=1e-5, qk_norm=True, - cap_feat_dim=2560, - rope_theta=256.0, - t_scale=1000.0, - axes_dims=[32, 48, 48], - axes_lens=[1024, 512, 512], control_layers_places: List[int]=None, control_in_dim=None, ): super().__init__() - self.in_channels = in_channels - self.out_channels = in_channels - self.all_patch_size = all_patch_size - self.all_f_patch_size = all_f_patch_size - self.dim = dim - self.n_heads = n_heads - - self.rope_theta = rope_theta - self.t_scale = t_scale - self.gradient_checkpointing = False - self.n_layers = n_layers - - assert len(all_patch_size) == len(all_f_patch_size) - - all_x_embedder = {} - for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) - all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - - self.all_x_embedder = nn.ModuleDict(all_x_embedder) - self.noise_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.context_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=False, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) - self.cap_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear(cap_feat_dim, dim, bias=True), - ) - - self.x_pad_token = nn.Parameter(torch.empty((1, dim))) - self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) - - self.axes_dims = axes_dims - self.axes_lens = axes_lens - - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - - ## Original Control layers - self.control_layers_places = control_layers_places self.control_in_dim = control_in_dim @@ -366,6 +295,7 @@ def patchify_and_embed( def 
forward( self, + transformer: ZImageTransformer2DModel, x: List[torch.Tensor], cap_feats: List[torch.Tensor], control_context: List[torch.Tensor], @@ -380,7 +310,7 @@ def forward( bsz = len(x) device = x[0].device t = t * self.t_scale - t = self.t_embedder(t) + t = transformer.t_embedder(t) ( x, @@ -398,13 +328,13 @@ def forward( x_max_item_seqlen = max(x_item_seqlens) x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + x = transformer.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) # Match t_embedder output dtype to x for layerwise casting compatibility adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) x = pad_sequence(x, batch_first=True, padding_value=0.0) x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) @@ -413,10 +343,10 @@ def forward( x_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: + for layer in transformer.noise_refiner: x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) else: - for layer in self.noise_refiner: + for layer in transformer.noise_refiner: x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) # cap embed & refine @@ -425,10 +355,10 @@ def forward( cap_max_item_seqlen = max(cap_item_seqlens) cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = transformer.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) @@ -437,10 +367,10 @@ def forward( cap_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.context_refiner: + for layer in transformer.context_refiner: cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) else: - for layer in self.context_refiner: + for layer in transformer.context_refiner: cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) # unified @@ -485,7 +415,7 @@ def forward( adaln_input = t.type_as(control_context) control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token control_context = list(control_context.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 609b141be796..d374b8032ea8 100644 --- 
a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -594,6 +594,7 @@ def __call__( latent_model_input_list = list(latent_model_input.unbind(dim=0)) controlnet_block_samples = self.controlnet( + self.transformer, latent_model_input_list, prompt_embeds_model_input, control_image, From 0c308394049f2c7a65c697cf88ea9c40d9ca4333 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 17:50:01 +0000 Subject: [PATCH 03/38] ruff --- ...convert_z_image_controlnet_to_diffusers.py | 21 ++++++++--- .../models/controlnets/controlnet_z_image.py | 36 +++++++++---------- .../transformers/transformer_z_image.py | 2 +- .../z_image/pipeline_z_image_controlnet.py | 3 +- 4 files changed, 36 insertions(+), 26 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index c4b96cda02af..a9f97d81676d 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,14 +1,15 @@ import argparse from contextlib import nullcontext -import torch import safetensors.torch +import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers.utils.import_utils import is_accelerate_available from diffusers.models import ZImageTransformer2DModel from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel +from diffusers.utils.import_utils import is_accelerate_available + """ python scripts/convert_z_image_controlnet_to_diffusers.py \ @@ -42,16 +43,28 @@ def load_original_checkpoint(args): original_state_dict = safetensors.torch.load_file(ckpt_path) return original_state_dict + def load_z_image(args): - model = ZImageTransformer2DModel.from_pretrained(args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16) + model = ZImageTransformer2DModel.from_pretrained( + args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ) return model.state_dict(), model.config + def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): converted_state_dict = {} converted_state_dict.update(original_state_dict) - to_copy = {"all_x_embedder.", "noise_refiner.", "context_refiner.", "t_embedder.", "cap_embedder.", "x_pad_token", "cap_pad_token"} + to_copy = { + "all_x_embedder.", + "noise_refiner.", + "context_refiner.", + "t_embedder.", + "cap_embedder.", + "x_pad_token", + "cap_pad_token", + } for key in z_image.keys(): for copy_key in to_copy: diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 6fe9d38ce3d1..b76a2c54c3d8 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -20,15 +20,18 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin -from ...models.normalization import RMSNorm from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin -from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM +from ..transformers.transformer_z_image import ( + SEQ_MULTI_OF, + ZImageTransformer2DModel, + ZImageTransformerBlock, +) class ZImageControlTransformerBlock(ZImageTransformerBlock): def __init__( - self, + self, layer_id: int, dim: int, n_heads: 
int, @@ -36,7 +39,7 @@ def __init__( norm_eps: float, qk_norm: bool, modulation=True, - block_id=0 + block_id=0, ): super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) self.block_id = block_id @@ -57,7 +60,8 @@ def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): all_c += [c_skip, c] c = torch.stack(all_c) return c - + + class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True @@ -72,7 +76,7 @@ def __init__( n_kv_heads=30, norm_eps=1e-5, qk_norm=True, - control_layers_places: List[int]=None, + control_layers_places: List[int] = None, control_in_dim=None, ): super().__init__() @@ -84,15 +88,7 @@ def __init__( # control blocks self.control_layers = nn.ModuleList( [ - ZImageControlTransformerBlock( - i, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - block_id=i - ) + ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places ] ) @@ -425,7 +421,9 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for layer in self.control_noise_refiner: - control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input) + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) else: for layer in self.control_noise_refiner: control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) @@ -440,14 +438,14 @@ def forward( control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) c = control_context_unified - new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input) - + new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} + for layer in self.control_layers: if torch.is_grad_enabled() and self.gradient_checkpointing: c = self._gradient_checkpointing_func(layer, c, **new_kwargs) else: c = layer(c, **new_kwargs) - + hints = torch.unbind(c)[:-1] * conditioning_scale controlnet_block_samples = {} for layer_idx in range(self.n_layers): diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 2d332217d897..70ffced8b63a 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,7 +538,7 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, - controlnet_block_samples: Optional[dict[int, torch.Tensor]]=None, + controlnet_block_samples: Optional[dict[int, torch.Tensor]] = None, return_dict: bool = True, ): assert patch_size in self.all_patch_size diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index d374b8032ea8..44906a0db519 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -89,7 +89,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -509,7 +508,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, device=device, dtype=self.vae.dtype, - ) + ) height, width = 
control_image.shape[-2:] control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor From 52f996e226dfd1e7f1a1b0d001c022dae71e24a8 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:23 +0000 Subject: [PATCH 04/38] convert passed --- ...convert_z_image_controlnet_to_diffusers.py | 58 ++----------------- 1 file changed, 6 insertions(+), 52 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index a9f97d81676d..aed27c14f205 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,19 +1,17 @@ import argparse from contextlib import nullcontext -import safetensors.torch import torch +import safetensors.torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers.models import ZImageTransformer2DModel from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel from diffusers.utils.import_utils import is_accelerate_available """ python scripts/convert_z_image_controlnet_to_diffusers.py \ ---original_z_image_repo_id "Tongyi-MAI/Z-Image-Turbo" \ --original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \ --filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" --output_path "z-image-controlnet-hf/" @@ -23,7 +21,6 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext parser = argparse.ArgumentParser() -parser.add_argument("--original_z_image_repo_id", default="Tongyi-MAI/Z-Image-Turbo", type=str) parser.add_argument("--original_controlnet_repo_id", default=None, type=str) parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str) parser.add_argument("--checkpoint_path", default=None, type=str) @@ -44,72 +41,29 @@ def load_original_checkpoint(args): return original_state_dict -def load_z_image(args): - model = ZImageTransformer2DModel.from_pretrained( - args.original_z_image_repo_id, subfolder="transformer", torch_dtype=torch.bfloat16 - ) - return model.state_dict(), model.config - - -def convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_state_dict): +def convert_z_image_controlnet_checkpoint_to_diffusers(original_state_dict): converted_state_dict = {} converted_state_dict.update(original_state_dict) - to_copy = { - "all_x_embedder.", - "noise_refiner.", - "context_refiner.", - "t_embedder.", - "cap_embedder.", - "x_pad_token", - "cap_pad_token", - } - - for key in z_image.keys(): - for copy_key in to_copy: - if key.startswith(copy_key): - converted_state_dict[key] = z_image[key] - return converted_state_dict def main(args): original_ckpt = load_original_checkpoint(args) - z_image, config = load_z_image(args) control_in_dim = 16 control_layers_places = [0, 5, 10, 15, 20, 25] - converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(z_image, original_ckpt) - - for key, tensor in converted_controlnet_state_dict.items(): - print(f"{key} - {tensor.dtype}") + converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(original_ckpt) controlnet = ZImageControlNetModel( - all_patch_size=config["all_patch_size"], - all_f_patch_size=config["all_f_patch_size"], - in_channels=config["in_channels"], - dim=config["dim"], - n_layers=config["n_layers"], - n_refiner_layers=config["n_refiner_layers"], - 
n_heads=config["n_heads"], - n_kv_heads=config["n_kv_heads"], - norm_eps=config["norm_eps"], - qk_norm=config["qk_norm"], - cap_feat_dim=config["cap_feat_dim"], - rope_theta=config["rope_theta"], - t_scale=config["t_scale"], - axes_dims=config["axes_dims"], - axes_lens=config["axes_lens"], control_layers_places=control_layers_places, control_in_dim=control_in_dim, - ) - missing, unexpected = controlnet.load_state_dict(converted_controlnet_state_dict) - print(f"{missing=}") - print(f"{unexpected=}") + ).to(torch.bfloat16) + controlnet.load_state_dict(converted_controlnet_state_dict) print("Saving Z-Image ControlNet in Diffusers format") - controlnet.save_pretrained(args.output_path, max_shard_size="5GB") + controlnet.save_pretrained(args.output_path) if __name__ == "__main__": From 4b446b394150575b322836b048acd3eeeb2072a3 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:30 +0000 Subject: [PATCH 05/38] __init__ --- src/diffusers/__init__.py | 4 ++++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/z_image/__init__.py | 2 ++ 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eb8e86c4c89d..f45be1560716 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -277,6 +277,7 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", + "ZImageControlNetModel", "attention_backend", ] ) @@ -661,6 +662,7 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", "ZImagePipeline", + "ZImageControlNetPipeline", ] ) @@ -1004,6 +1006,7 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageTransformer2DModel, + ZImageControlNetModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1357,6 +1360,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ZImagePipeline, + ZImageControlNetPipeline, ) try: diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29d8b0b5a55d..7ea15ef2a215 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -66,6 +66,7 @@ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"] _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] _import_structure["embeddings"] = ["ImageProjection"] @@ -180,6 +181,7 @@ SD3MultiControlNetModel, SparseControlNetModel, UNetControlNetXSModel, + ZImageControlNetModel, ) from .embeddings import ImageProjection from .modeling_utils import ModelMixin diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3d669aecf556..fe6af5cd1e0b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -404,7 +404,7 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] - _import_structure["z_image"] = ["ZImagePipeline"] + _import_structure["z_image"] = ["ZImagePipeline", "ZImageControlNetPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -841,7 +841,7 @@ WuerstchenDecoderPipeline, 
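With the simplified converter and the new top-level exports in place, the converted weights round-trip through the usual `from_pretrained` path. A loading sketch, assuming the script above was run with `--output_path "z-image-controlnet-hf/"`; both the local path and the base repo id are placeholders.

```py
import torch

from diffusers import ZImageControlNetModel, ZImageControlNetPipeline

controlnet = ZImageControlNetModel.from_pretrained("z-image-controlnet-hf/", torch_dtype=torch.bfloat16)
pipe = ZImageControlNetPipeline.from_pretrained(
    "Z-a-o/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
```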
WuerstchenPriorPipeline, ) - from .z_image import ZImagePipeline + from .z_image import ZImagePipeline, ZImageControlNetPipeline try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index f95b3e5a0bed..842d5690e3d7 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -23,6 +23,7 @@ else: _import_structure["pipeline_output"] = ["ZImagePipelineOutput"] _import_structure["pipeline_z_image"] = ["ZImagePipeline"] + _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -35,6 +36,7 @@ else: from .pipeline_output import ZImagePipelineOutput from .pipeline_z_image import ZImagePipeline + from .pipeline_z_image_controlnet import ZImageControlNetPipeline else: import sys From a1ff390ecebb5afb9f6282209526adfdaf31c5d5 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:03:37 +0000 Subject: [PATCH 06/38] pipeline example --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 44906a0db519..ae81105eea27 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -36,9 +36,13 @@ Examples: ```py >>> import torch - >>> from diffusers import ZImagePipeline + >>> from diffusers import ZImageControlNetPipeline + >>> from diffusers import ZImageControlNetModel - >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> controlnet_model = "..." + >>> controlnet = ZImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) + + >>> pipe = ZImageControlNetPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. @@ -47,9 +51,11 @@ >>> # (2) Use flash attention 3 >>> # pipe.transformer.set_attention_backend("_flash_3") - >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") + >>> prompt = "A girl in city, 25 years old, cool, futuristic" >>> image = pipe( ... prompt, + ... control_image=control_image, ... height=1024, ... width=1024, ... 
num_inference_steps=9, From 7ab347d812a5b78076f79541f4046d59948d0464 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:04:01 +0000 Subject: [PATCH 07/38] ruff --- scripts/convert_z_image_controlnet_to_diffusers.py | 2 +- src/diffusers/__init__.py | 4 ++-- src/diffusers/models/controlnets/controlnet_z_image.py | 7 ++++++- src/diffusers/pipelines/__init__.py | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py index aed27c14f205..e5d5f34e36e8 100644 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ b/scripts/convert_z_image_controlnet_to_diffusers.py @@ -1,8 +1,8 @@ import argparse from contextlib import nullcontext -import torch import safetensors.torch +import torch from accelerate import init_empty_weights from huggingface_hub import hf_hub_download diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f45be1560716..398f72167ad3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1005,8 +1005,8 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, - ZImageTransformer2DModel, ZImageControlNetModel, + ZImageTransformer2DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1359,8 +1359,8 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, - ZImagePipeline, ZImageControlNetPipeline, + ZImagePipeline, ) try: diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index b76a2c54c3d8..ff148781f49a 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -438,7 +438,12 @@ def forward( control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) c = control_context_unified - new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input} + new_kwargs = { + "x": unified, + "attn_mask": unified_attn_mask, + "freqs_cis": unified_freqs_cis, + "adaln_input": adaln_input, + } for layer in self.control_layers: if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index fe6af5cd1e0b..10ce49fe8111 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -841,7 +841,7 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .z_image import ZImagePipeline, ZImageControlNetPipeline + from .z_image import ZImageControlNetPipeline, ZImagePipeline try: if not is_onnx_available(): From 8cab0c953c7b732324a92d4e5de067b8bb290a5d Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:05:36 +0000 Subject: [PATCH 08/38] pipeline load_image --- src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index ae81105eea27..67771dddabd7 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -38,6 +38,7 @@ >>> import torch >>> from diffusers import ZImageControlNetPipeline >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import 
load_image >>> controlnet_model = "..." >>> controlnet = ZImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) From 8688fa66a110bd0d77df8d03299fc3a42130ce07 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:24:05 +0000 Subject: [PATCH 09/38] t_scale --- src/diffusers/models/controlnets/controlnet_z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index ff148781f49a..070724a85883 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -305,7 +305,7 @@ def forward( bsz = len(x) device = x[0].device - t = t * self.t_scale + t = t * transformer.t_scale t = transformer.t_embedder(t) ( From 9051272d47082c5cf6bc409b6368332cdac16f97 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:28:26 +0000 Subject: [PATCH 10/38] x_pad_token --- src/diffusers/models/controlnets/controlnet_z_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 070724a85883..48b9a66a25a3 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -409,7 +409,7 @@ def forward( # Match t_embedder output dtype to control_context for layerwise casting compatibility adaln_input = t.type_as(control_context) - control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token control_context = list(control_context.split(x_item_seqlens, dim=0)) x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) From 0d8c3f1a28180fc85fc9a4e0696d5f4f11def56f Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:28:34 +0000 Subject: [PATCH 11/38] controlnet_block_samples --- src/diffusers/models/controlnets/controlnet_z_image.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 48b9a66a25a3..0127f7f9683f 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -452,10 +452,6 @@ def forward( c = layer(c, **new_kwargs) hints = torch.unbind(c)[:-1] * conditioning_scale - controlnet_block_samples = {} - for layer_idx in range(self.n_layers): - if layer_idx in self.control_layers_places: - hints_idx = self.control_layers_places.index(layer_idx) - controlnet_block_samples[layer_idx] = hints[hints_idx] + controlnet_block_samples = {layer_idx: hints[idx] for idx, layer_idx in enumerate(self.control_layers_places)} return controlnet_block_samples From f789325ccd8f3f6fb35dffdd4acea6f21f30084e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:29:54 +0000 Subject: [PATCH 12/38] conditioning_scale --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 0127f7f9683f..3a200b252a01 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -451,7 +451,7 @@ def forward( else: c = layer(c, 
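The fixes above settle the interface between the ControlNet and the transformer: hints are packaged as a dict keyed by transformer layer index (the per-hint conditioning scale lands just below), and the transformer adds the matching residual after each block. A minimal consumption sketch under that assumption; `run_blocks_with_hints`, the block list, and the tensor shapes are illustrative only.

```py
from typing import Dict, Optional

import torch
import torch.nn as nn


def run_blocks_with_hints(
    blocks: nn.ModuleList,
    hidden_states: torch.Tensor,
    controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
) -> torch.Tensor:
    for layer_idx, block in enumerate(blocks):
        hidden_states = block(hidden_states)
        if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
            # Residual injection only at the indices listed in `control_layers_places`.
            hidden_states = hidden_states + controlnet_block_samples[layer_idx]
    return hidden_states


blocks = nn.ModuleList(nn.Identity() for _ in range(4))
hints = {0: torch.ones(1, 3, 8), 2: torch.ones(1, 3, 8)}
out = run_blocks_with_hints(blocks, torch.zeros(1, 3, 8), hints)  # residuals added after blocks 0 and 2
```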
**new_kwargs) - hints = torch.unbind(c)[:-1] * conditioning_scale - controlnet_block_samples = {layer_idx: hints[idx] for idx, layer_idx in enumerate(self.control_layers_places)} + hints = torch.unbind(c)[:-1] + controlnet_block_samples = {layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)} return controlnet_block_samples From 5f8ab7bf98549ff6bdc63db500ad1433b7cf84e2 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 18:33:36 +0000 Subject: [PATCH 13/38] self.config --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 3a200b252a01..d0f8b861e0c9 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -300,8 +300,8 @@ def forward( f_patch_size=1, conditioning_scale: float = 1.0, ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size + assert patch_size in self.config.all_patch_size + assert f_patch_size in self.config.all_f_patch_size bsz = len(x) device = x[0].device From bc72f9ce93ca691018fb8f1b420684b6e18a6d55 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 20:08:24 +0000 Subject: [PATCH 14/38] sample_mode, default controlnet_conditioning_scale --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 67771dddabd7..d0460cf09244 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -360,7 +360,7 @@ def __call__( sigmas: Optional[List[float]] = None, guidance_scale: float = 5.0, control_image: PipelineImageInput = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 0.75, cfg_normalization: bool = False, cfg_truncation: float = 1.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -517,7 +517,7 @@ def __call__( dtype=self.vae.dtype, ) height, width = control_image.shape[-2:] - control_image = retrieve_latents(self.vae.encode(control_image), generator=generator) + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor control_image = control_image.unsqueeze(2) From 13b706a99f209197352bcb8790260727d75b2b9b Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 4 Dec 2025 20:16:49 +0000 Subject: [PATCH 15/38] ruff --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index d0f8b861e0c9..c121f42c1a78 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -452,6 +452,8 @@ def forward( c = layer(c, **new_kwargs) hints = torch.unbind(c)[:-1] - controlnet_block_samples = {layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)} + controlnet_block_samples = { + layer_idx: hints[idx] * conditioning_scale for 
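The pipeline-side change above encodes the control image deterministically (`sample_mode="argmax"`) and applies the VAE's shift and scaling factors before adding a frame axis. A standalone sketch of that preprocessing, assuming `vae` is an `AutoencoderKL`-style model whose config exposes `shift_factor` and `scaling_factor` and `image` is an already-preprocessed tensor.

```py
import torch


@torch.no_grad()
def encode_control_image(vae, image: torch.Tensor) -> torch.Tensor:
    posterior = vae.encode(image.to(vae.dtype)).latent_dist
    latents = posterior.mode()  # deterministic latents, matching sample_mode="argmax"
    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
    return latents.unsqueeze(2)  # add a frame axis: (B, C, H, W) -> (B, C, 1, H, W)
```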
idx, layer_idx in enumerate(self.control_layers_places) + } return controlnet_block_samples From 09849a77465e49f6d1ca056638d917b6937f4b95 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 20:48:16 +0000 Subject: [PATCH 16/38] ZImageControlTransformer2DModel --- src/diffusers/__init__.py | 2 + src/diffusers/loaders/peft.py | 1 + src/diffusers/models/__init__.py | 2 + .../models/controlnets/controlnet_z_image.py | 360 +------- src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_z_image.py | 7 - .../transformer_z_image_control.py | 784 ++++++++++++++++++ .../z_image/pipeline_z_image_controlnet.py | 18 +- 8 files changed, 808 insertions(+), 367 deletions(-) create mode 100644 src/diffusers/models/transformers/transformer_z_image_control.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 398f72167ad3..746021bfd706 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -277,6 +277,7 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", + "ZImageControlTransformer2DModel", "ZImageControlNetModel", "attention_backend", ] @@ -1006,6 +1007,7 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageControlNetModel, + ZImageControlTransformer2DModel, ZImageTransformer2DModel, attention_backend, ) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 3f8519bbfa32..62182f2d205f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -65,6 +65,7 @@ "QwenImageTransformer2DModel": lambda model_cls, weights: weights, "Flux2Transformer2DModel": lambda model_cls, weights: weights, "ZImageTransformer2DModel": lambda model_cls, weights: weights, + "ZImageControlTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ea15ef2a215..48d06def1c3c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -117,6 +117,7 @@ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] + _import_structure["transformers.transformer_z_image_control"] = ["ZImageControlTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -231,6 +232,7 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, + ZImageControlTransformer2DModel, ZImageTransformer2DModel, ) from .unets import ( diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index c121f42c1a78..0972fb46c07b 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -12,19 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
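This commit folds the control branch into a combined `ZImageControlTransformer2DModel` (exported from `diffusers` and registered for PEFT/LoRA) and trims `ZImageControlNetModel` down to a weight container with a stubbed forward. An assembly sketch using the `from_controlnet` classmethod added later in this patch; the repo id and local path are placeholders, and since the helper is still being fleshed out at this point in the series, treat this as intended usage rather than a tested call.

```py
import torch

from diffusers import ZImageControlNetModel, ZImageControlTransformer2DModel, ZImageTransformer2DModel

transformer = ZImageTransformer2DModel.from_pretrained(
    "Z-a-o/Z-Image-Turbo", subfolder="transformer", torch_dtype=torch.bfloat16
)
controlnet = ZImageControlNetModel.from_pretrained("z-image-controlnet-hf/", torch_dtype=torch.bfloat16)

# Builds the combined model from the transformer's config plus the control config,
# then copies the control-specific weights across.
control_transformer = ZImageControlTransformer2DModel.from_controlnet(transformer, controlnet)
```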
-from typing import List +from typing import List, Optional import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin from ..transformers.transformer_z_image import ( - SEQ_MULTI_OF, - ZImageTransformer2DModel, ZImageTransformerBlock, ) @@ -47,7 +44,14 @@ def __init__( self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) self.after_proj = zero_module(nn.Linear(self.dim, self.dim)) - def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): + def forward( + self, + c: torch.Tensor, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): if self.block_id == 0: c = self.before_proj(c) + x all_c = [] @@ -55,7 +59,7 @@ def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs): all_c = list(torch.unbind(c)) c = all_c.pop(-1) - c = super().forward(c, **kwargs) + c = super().forward(c, attn_mask, freqs_cis, adaln_input) c_skip = self.after_proj(c) all_c += [c_skip, c] c = torch.stack(all_c) @@ -115,345 +119,5 @@ def __init__( ] ) - @staticmethod - def create_coordinate_grid(size, start=None, device=None): - if start is None: - start = (0 for _ in size) - - axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] - grids = torch.meshgrid(axes, indexing="ij") - return torch.stack(grids, dim=-1) - - def patchify( - self, - all_image: List[torch.Tensor], - patch_size: int, - f_patch_size: int, - cap_padding_len: int, - ): - pH = pW = patch_size - pF = f_patch_size - device = all_image[0].device - - all_image_out = [] - all_image_size = [] - all_image_pos_ids = [] - all_image_pad_mask = [] - - for i, image in enumerate(all_image): - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(image_padding_len, 1) - ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) - all_image_pos_ids.append(image_padded_pos_ids) - # pad mask - all_image_pad_mask.append( - torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - ) - # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) - all_image_out.append(image_padded_feat) - - return ( - all_image_out, - all_image_size, - all_image_pos_ids, - all_image_pad_mask, - ) - - def patchify_and_embed( - self, - all_image: List[torch.Tensor], - all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, - ): - pH = pW = patch_size - pF = f_patch_size - device = all_image[0].device - - all_image_out = [] - all_image_size = [] - 
all_image_pos_ids = [] - all_image_pad_mask = [] - all_cap_pos_ids = [] - all_cap_pad_mask = [] - all_cap_feats_out = [] - - for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): - ### Process Caption - cap_ori_len = len(cap_feat) - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF - # padded position ids - cap_padded_pos_ids = self.create_coordinate_grid( - size=(cap_ori_len + cap_padding_len, 1, 1), - start=(1, 0, 0), - device=device, - ).flatten(0, 2) - all_cap_pos_ids.append(cap_padded_pos_ids) - # pad mask - all_cap_pad_mask.append( - torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - ) - # padded feature - cap_padded_feat = torch.cat( - [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], - dim=0, - ) - all_cap_feats_out.append(cap_padded_feat) - - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_ori_len + cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(image_padding_len, 1) - ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) - all_image_pos_ids.append(image_padded_pos_ids) - # pad mask - all_image_pad_mask.append( - torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - ) - # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) - all_image_out.append(image_padded_feat) - - return ( - all_image_out, - all_cap_feats_out, - all_image_size, - all_image_pos_ids, - all_cap_pos_ids, - all_image_pad_mask, - all_cap_pad_mask, - ) - - def forward( - self, - transformer: ZImageTransformer2DModel, - x: List[torch.Tensor], - cap_feats: List[torch.Tensor], - control_context: List[torch.Tensor], - t=None, - patch_size=2, - f_patch_size=1, - conditioning_scale: float = 1.0, - ): - assert patch_size in self.config.all_patch_size - assert f_patch_size in self.config.all_f_patch_size - - bsz = len(x) - device = x[0].device - t = t * transformer.t_scale - t = transformer.t_embedder(t) - - ( - x, - cap_feats, - x_size, - x_pos_ids, - cap_pos_ids, - x_inner_pad_mask, - cap_inner_pad_mask, - ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) - - # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - x = torch.cat(x, dim=0) - x = transformer.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) - - # Match t_embedder output dtype to x for layerwise casting compatibility - adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) - - x = 
pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in transformer.noise_refiner: - x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) - else: - for layer in transformer.noise_refiner: - x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) - - # cap embed & refine - cap_item_seqlens = [len(_) for _ in cap_feats] - assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) - cap_max_item_seqlen = max(cap_item_seqlens) - - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = transformer.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) - - cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in transformer.context_refiner: - cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) - else: - for layer in transformer.context_refiner: - cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) - - # unified - unified = [] - unified_freqs_cis = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) - - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 - - ## Original forward_control - - # embeddings - bsz = len(control_context) - device = control_context[0].device - ( - control_context, - x_size, - x_pos_ids, - x_inner_pad_mask, - ) = self.patchify(control_context, patch_size, f_patch_size, cap_feats[0].size(0)) - - # control_context embed & refine - x_item_seqlens = [len(_) for _ in control_context] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - control_context = torch.cat(control_context, dim=0) - control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) - - # Match t_embedder output dtype to control_context for layerwise casting compatibility - adaln_input = t.type_as(control_context) - control_context[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token - control_context = list(control_context.split(x_item_seqlens, dim=0)) - x_freqs_cis = 
list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) - - control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.control_noise_refiner: - control_context = self._gradient_checkpointing_func( - layer, control_context, x_attn_mask, x_freqs_cis, adaln_input - ) - else: - for layer in self.control_noise_refiner: - control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) - - # unified - cap_item_seqlens = [len(_) for _ in cap_feats] - control_context_unified = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) - control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) - c = control_context_unified - - new_kwargs = { - "x": unified, - "attn_mask": unified_attn_mask, - "freqs_cis": unified_freqs_cis, - "adaln_input": adaln_input, - } - - for layer in self.control_layers: - if torch.is_grad_enabled() and self.gradient_checkpointing: - c = self._gradient_checkpointing_func(layer, c, **new_kwargs) - else: - c = layer(c, **new_kwargs) - - hints = torch.unbind(c)[:-1] - controlnet_block_samples = { - layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places) - } - - return controlnet_block_samples + def forward(self): + pass diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a42f6b2716e1..13322aa29ae4 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -48,3 +48,4 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel + from .transformer_z_image_control import ZImageControlTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 70ffced8b63a..7c01361b681d 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -538,7 +538,6 @@ def forward( cap_feats: List[torch.Tensor], patch_size=2, f_patch_size=1, - controlnet_block_samples: Optional[dict[int, torch.Tensor]] = None, return_dict: bool = True, ): assert patch_size in self.all_patch_size @@ -640,15 +639,9 @@ def forward( unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] else: for layer_idx, layer in enumerate(self.layers): unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) unified = list(unified.unbind(dim=0)) diff --git 
a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py new file mode 100644 index 000000000000..61b752f58f68 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -0,0 +1,784 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...models.attention_processor import Attention +from ...models.modeling_utils import ModelMixin +from ...models.normalization import RMSNorm +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_dispatch import dispatch_attention_fn +from ..modeling_outputs import Transformer2DModelOutput +from .transformer_z_image import ZImageTransformer2DModel + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) + if weight_dtype.is_floating_point: + t_freq = t_freq.to(weight_dtype) + elif compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class ZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
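The timestep embedder re-declared above builds frequencies that fall off geometrically with `max_period` and concatenates cosines and sines before the MLP. A small standalone check of that construction, mirroring the `timestep_embedding` staticmethod.

```py
import math

import torch

t = torch.tensor([0.0, 500.0])
dim, max_period = 8, 10000
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = t[:, None] * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
assert emb.shape == (2, dim)  # first half cosines, second half sines
```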
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + + +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +@maybe_allow_in_graph +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: 
Optional[torch.Tensor] = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ) + + def forward(self, x, c): + scale = 1.0 + self.adaLN_modulation(c) + x = self.norm_final(x) * scale.unsqueeze(1) + x = self.linear(x) + return x + + +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) + + +class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + _supports_gradient_checkpointing = True + _no_split_modules = ["ZImageTransformerBlock", "ZImageControlTransformerBlock"] + _repeated_blocks = ["ZImageTransformerBlock", "ZImageControlTransformerBlock"] + _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers + + @register_to_config + def __init__( + self, + all_patch_size=(2,), + all_f_patch_size=(1,), + in_channels=16, + dim=3840, + n_layers=30, + n_refiner_layers=2, + n_heads=30, + 
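A quick check of the geometry implied by the defaults above (dim=3840, n_heads=30) together with the RoPE axis dims set just below ([32, 48, 48]), which `__init__` asserts later on: the per-head dimension must equal the sum of the axis dims.

```py
dim, n_heads = 3840, 30
axes_dims = [32, 48, 48]
head_dim = dim // n_heads
assert head_dim == sum(axes_dims) == 128
```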
n_kv_heads=30, + norm_eps=1e-5, + qk_norm=True, + cap_feat_dim=2560, + rope_theta=256.0, + t_scale=1000.0, + axes_dims=[32, 48, 48], + axes_lens=[1024, 512, 512], + control_layers_places: List[int] = None, + control_in_dim=None, + ) -> None: + super().__init__() + from ...models.controlnets.controlnet_z_image import ZImageControlTransformerBlock + + self.in_channels = in_channels + self.out_channels = in_channels + self.all_patch_size = all_patch_size + self.all_f_patch_size = all_f_patch_size + self.dim = dim + self.n_heads = n_heads + + self.rope_theta = rope_theta + self.t_scale = t_scale + self.gradient_checkpointing = False + + assert len(all_patch_size) == len(all_f_patch_size) + + all_x_embedder = {} + all_final_layer = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer + + self.all_x_embedder = nn.ModuleDict(all_x_embedder) + self.all_final_layer = nn.ModuleDict(all_final_layer) + self.noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.context_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=False, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) + + self.x_pad_token = nn.Parameter(torch.empty((1, dim))) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) + + self.layers = nn.ModuleList( + [ + ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + for layer_id in range(n_layers) + ] + ) + head_dim = dim // n_heads + assert head_dim == sum(axes_dims) + self.axes_dims = axes_dims + self.axes_lens = axes_lens + + self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) + + self.control_layers_places = control_layers_places + self.control_in_dim = control_in_dim + + assert 0 in self.control_layers_places + + # control blocks + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) + for i in self.control_layers_places + ] + ) + + # control patch embeddings + all_x_embedder = {} + for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): + x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + @classmethod + def from_controlnet( + cls, + transformer: ZImageTransformer2DModel, + controlnet, + load_weights: bool = True, + ): + controlnet.to(device=transformer.device) + + if transformer.config["dim"] != 
controlnet.config["dim"]: + raise ValueError("Incompatible ControlNet, got a different dim.") + + config = dict(transformer.config) + config["_class_name"] = cls.__name__ + + config["control_layers_places"] = controlnet.config["control_layers_places"] + config["control_in_dim"] = controlnet.config["control_in_dim"] + + expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) + config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs}) + config["_class_name"] = cls.__name__ + model = cls.from_config(config) + + if not load_weights: + return model + + for i, control_layer in enumerate(controlnet.control_layers): + model.control_layers[i].load_state_dict(control_layer.state_dict()) + + for i, control_all_x_embedder in enumerate(controlnet.control_all_x_embedder): + model.control_all_x_embedder[i].load_state_dict(control_all_x_embedder.state_dict()) + + for i, control_noise_refiner in enumerate(controlnet.control_noise_refiner): + model.control_noise_refiner[i].load_state_dict(control_noise_refiner.state_dict()) + + model.to(transformer.dtype) + + return model + + def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: + pH = pW = patch_size + pF = f_patch_size + bsz = len(x) + assert len(size) == bsz + for i in range(bsz): + F, H, W = size[i] + ori_len = (F // pF) * (H // pH) * (W // pW) + # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" + x[i] = ( + x[i][:ori_len] + .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) + .permute(6, 0, 3, 1, 4, 2, 5) + .reshape(self.out_channels, F, H, W) + ) + return x + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + cap_pad_mask = torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + all_cap_pad_mask.append( + cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) + ) + + # padded feature + cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 
0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padded_pos_ids = torch.cat( + [ + image_ori_pos_ids, + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(image_padding_len, 1), + ], + dim=0, + ) + all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) + # pad mask + image_pad_mask = torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + all_image_pad_mask.append( + image_pad_mask + if image_padding_len > 0 + else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) + ) + # padded feature + image_padded_feat = torch.cat( + [image, image[-1:].repeat(image_padding_len, 1)], + dim=0, + ) + all_image_out.append(image_padded_feat if image_padding_len > 0 else image) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + patch_size=2, + f_patch_size=1, + control_context: Optional[List[torch.Tensor]] = None, + conditioning_scale: float = 1.0, + return_dict: bool = True, + ): + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = 
list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + ## ControlNet start + + controlnet_block_samples = None + if control_context is not None: + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.control_noise_refiner: + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # unified + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + control_context_unified = self._gradient_checkpointing_func( + layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + else: + control_context_unified = layer( + control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + + hints = torch.unbind(control_context_unified)[:-1] + controlnet_block_samples = { + layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in 
enumerate(self.control_layers_places) + } + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer_idx, layer in enumerate(self.layers): + unified = self._gradient_checkpointing_func( + layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + else: + for layer_idx, layer in enumerate(self.layers): + unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] + + unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) + unified = list(unified.unbind(dim=0)) + x = self.unpatchify(unified, x_size, patch_size, f_patch_size) + + if not return_dict: + return (x,) + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index d0460cf09244..2faea94fe134 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -22,7 +22,7 @@ from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets import ZImageControlNetModel -from ...models.transformers import ZImageTransformer2DModel +from ...models.transformers import ZImageControlTransformer2DModel, ZImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring @@ -167,10 +167,12 @@ def __init__( vae: AutoencoderKL, text_encoder: PreTrainedModel, tokenizer: AutoTokenizer, - transformer: ZImageTransformer2DModel, + transformer: Union[ZImageControlTransformer2DModel, ZImageTransformer2DModel], controlnet: ZImageControlNetModel, ): super().__init__() + if isinstance(transformer, ZImageTransformer2DModel): + transformer = ZImageControlTransformer2DModel.from_controlnet(transformer, controlnet) self.register_modules( vae=vae, @@ -599,20 +601,12 @@ def __call__( latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) - controlnet_block_samples = self.controlnet( - self.transformer, - latent_model_input_list, - prompt_embeds_model_input, - control_image, - timestep_model_input, - conditioning_scale=controlnet_conditioning_scale, - ) - model_out_list = self.transformer( latent_model_input_list, timestep_model_input, prompt_embeds_model_input, - controlnet_block_samples=controlnet_block_samples, + control_context=control_image, + conditioning_scale=controlnet_conditioning_scale, )[0] if apply_cfg: From f63a5a8ddf474ae76e04bbf98f9cb63a7d2f1c24 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:11:44 +0000 Subject: [PATCH 17/38] ModuleDict --- .../models/transformers/transformer_z_image_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py index 61b752f58f68..90c5bd6c7f24 100644 --- a/src/diffusers/models/transformers/transformer_z_image_control.py +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -471,8 +471,8 @@ def from_controlnet( for i, 
control_layer in enumerate(controlnet.control_layers): model.control_layers[i].load_state_dict(control_layer.state_dict()) - for i, control_all_x_embedder in enumerate(controlnet.control_all_x_embedder): - model.control_all_x_embedder[i].load_state_dict(control_all_x_embedder.state_dict()) + for key, control_all_x_embedder in controlnet.control_all_x_embedder.items(): + model.control_all_x_embedder[key].load_state_dict(control_all_x_embedder.state_dict()) for i, control_noise_refiner in enumerate(controlnet.control_noise_refiner): model.control_noise_refiner[i].load_state_dict(control_noise_refiner.state_dict()) From f9540cbb14e02e4498e123ddbe9d0fd6031ed830 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:13:07 +0000 Subject: [PATCH 18/38] patchify control_context --- .../transformer_z_image_control.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py index 90c5bd6c7f24..9d48bb15cc6a 100644 --- a/src/diffusers/models/transformers/transformer_z_image_control.py +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -610,6 +610,34 @@ def patchify_and_embed( all_cap_pad_mask, ) + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + all_image_out = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return all_image_out + def forward( self, x: List[torch.Tensor], @@ -719,6 +747,7 @@ def forward( controlnet_block_samples = None if control_context is not None: + control_context = self.patchify(control_context, patch_size, f_patch_size) control_context = torch.cat(control_context, dim=0) control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) From 3e472ac4a43dd7082dbe7d0851a0705afa72f0aa Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:23:36 +0000 Subject: [PATCH 19/38] transformer weights --- .../transformer_z_image_control.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py index 9d48bb15cc6a..55116bb75690 100644 --- a/src/diffusers/models/transformers/transformer_z_image_control.py +++ b/src/diffusers/models/transformers/transformer_z_image_control.py @@ -468,6 +468,29 @@ def from_controlnet( if not load_weights: return model + for key, all_x_embedder in transformer.all_x_embedder.items(): + model.all_x_embedder[key].load_state_dict(all_x_embedder.state_dict()) + + for key, all_final_layer in transformer.all_final_layer.items(): + model.all_final_layer[key].load_state_dict(all_final_layer.state_dict()) + + for i, noise_refiner in enumerate(transformer.noise_refiner): + model.noise_refiner[i].load_state_dict(noise_refiner.state_dict()) + + for i, context_refiner in 
enumerate(transformer.context_refiner): + model.context_refiner[i].load_state_dict(context_refiner.state_dict()) + + model.t_embedder.load_state_dict(transformer.t_embedder.state_dict()) + + model.cap_embedder.load_state_dict(transformer.cap_embedder.state_dict()) + + model.x_pad_token = transformer.x_pad_token + + model.cap_pad_token = transformer.cap_pad_token + + for i, layer in enumerate(transformer.layers): + model.layers[i].load_state_dict(layer.state_dict()) + for i, control_layer in enumerate(controlnet.control_layers): model.control_layers[i].load_state_dict(control_layer.state_dict()) From 0e7c643f02b523dd37b3b598e2b16c4c6839f4e9 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 5 Dec 2025 21:30:08 +0000 Subject: [PATCH 20/38] -enumerate in ZImageTransformer2DModel --- src/diffusers/models/transformers/transformer_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 7c01361b681d..5c401b9d202b 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -635,12 +635,12 @@ def forward( unified_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer_idx, layer in enumerate(self.layers): + for layer in self.layers: unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) else: - for layer_idx, layer in enumerate(self.layers): + for layer in self.layers: unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) From a00f1048e0a214fd09764cff5dd3afb0300fba1d Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 11 Dec 2025 01:23:06 +0000 Subject: [PATCH 21/38] Option 3 --- src/diffusers/__init__.py | 2 - src/diffusers/loaders/peft.py | 1 - src/diffusers/models/__init__.py | 2 - .../models/controlnets/controlnet_z_image.py | 395 ++++++++- src/diffusers/models/transformers/__init__.py | 1 - .../transformers/transformer_z_image.py | 13 +- .../transformer_z_image_control.py | 836 ------------------ .../z_image/pipeline_z_image_controlnet.py | 18 +- 8 files changed, 412 insertions(+), 856 deletions(-) delete mode 100644 src/diffusers/models/transformers/transformer_z_image_control.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f9a57bdfc85b..f0af3513255a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -279,7 +279,6 @@ "WanTransformer3DModel", "WanVACETransformer3DModel", "ZImageTransformer2DModel", - "ZImageControlTransformer2DModel", "ZImageControlNetModel", "attention_backend", ] @@ -1016,7 +1015,6 @@ WanTransformer3DModel, WanVACETransformer3DModel, ZImageControlNetModel, - ZImageControlTransformer2DModel, ZImageTransformer2DModel, attention_backend, ) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 62182f2d205f..3f8519bbfa32 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -65,7 +65,6 @@ "QwenImageTransformer2DModel": lambda model_cls, weights: weights, "Flux2Transformer2DModel": lambda model_cls, weights: weights, "ZImageTransformer2DModel": lambda model_cls, weights: weights, - "ZImageControlTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 
48d06def1c3c..7ea15ef2a215 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -117,7 +117,6 @@ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"] - _import_structure["transformers.transformer_z_image_control"] = ["ZImageControlTransformer2DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -232,7 +231,6 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, - ZImageControlTransformer2DModel, ZImageTransformer2DModel, ) from .unets import ( diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 0972fb46c07b..ea8eaf0a0788 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -16,17 +16,28 @@ import torch import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin +from ...models.attention_processor import Attention +from ...models.normalization import RMSNorm +from ...utils.torch_utils import maybe_allow_in_graph from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin from ..transformers.transformer_z_image import ( + ADALN_EMBED_DIM, + SEQ_MULTI_OF, + FeedForward, + RopeEmbedder, + TimestepEmbedder, ZImageTransformerBlock, + ZSingleStreamAttnProcessor, ) -class ZImageControlTransformerBlock(ZImageTransformerBlock): +@maybe_allow_in_graph +class ZImageControlTransformerBlock(nn.Module): def __init__( self, layer_id: int, @@ -38,7 +49,38 @@ def __init__( modulation=True, block_id=0, ): - super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation) + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + # Control variant start self.block_id = block_id if block_id == 0: self.before_proj = zero_module(nn.Linear(self.dim, self.dim)) @@ -52,6 +94,7 @@ def forward( freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor] = None, ): + # Control if self.block_id == 0: c = self.before_proj(c) + x all_c = [] @@ -59,7 +102,30 @@ def forward( all_c = list(torch.unbind(c)) c = all_c.pop(-1) - c = super().forward(c, attn_mask, freqs_cis, adaln_input) + # Compared to `ZImageTransformerBlock` x -> c + if self.modulation: + assert adaln_input is 
not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(c) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + c = c + gate_msa * self.attention_norm2(attn_out) + + # FFN block + c = c + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(c) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(c), attention_mask=attn_mask, freqs_cis=freqs_cis) + c = c + self.attention_norm2(attn_out) + + # FFN block + c = c + self.ffn_norm2(self.feed_forward(self.ffn_norm1(c))) + + # Control c_skip = self.after_proj(c) all_c += [c_skip, c] c = torch.stack(all_c) @@ -119,5 +185,324 @@ def __init__( ] ) - def forward(self): - pass + self.t_embedder: Optional[TimestepEmbedder] = None + self.all_x_embedder: Optional[nn.ModuleDict] = None + self.cap_embedder: Optional[nn.Sequential] = None + self.rope_embedder: Optional[RopeEmbedder] = None + self.noise_refiner: Optional[nn.ModuleList] = None + self.context_refiner: Optional[nn.ModuleList] = None + self.x_pad_token: Optional[nn.Parameter] = None + self.cap_pad_token: Optional[nn.Parameter] = None + + @classmethod + def from_transformer(cls, controlnet, transformer): + controlnet.t_embedder = transformer.t_embedder + controlnet.all_x_embedder = transformer.all_x_embedder + controlnet.cap_embedder = transformer.cap_embedder + controlnet.rope_embedder = transformer.rope_embedder + controlnet.noise_refiner = transformer.noise_refiner + controlnet.context_refiner = transformer.context_refiner + controlnet.x_pad_token = transformer.x_pad_token + controlnet.cap_pad_token = transformer.cap_pad_token + return controlnet + + @staticmethod + def create_coordinate_grid(size, start=None, device=None): + if start is None: + start = (0 for _ in size) + + axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] + grids = torch.meshgrid(axes, indexing="ij") + return torch.stack(grids, dim=-1) + + def patchify_and_embed( + self, + all_image: List[torch.Tensor], + all_cap_feats: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + device = all_image[0].device + + all_image_out = [] + all_image_size = [] + all_image_pos_ids = [] + all_image_pad_mask = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_feats_out = [] + + for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): + ### Process Caption + cap_ori_len = len(cap_feat) + cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF + # padded position ids + cap_padded_pos_ids = self.create_coordinate_grid( + size=(cap_ori_len + cap_padding_len, 1, 1), + start=(1, 0, 0), + device=device, + ).flatten(0, 2) + all_cap_pos_ids.append(cap_padded_pos_ids) + # pad mask + cap_pad_mask = torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + all_cap_pad_mask.append( + cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) + ) + + # padded feature + cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) + all_cap_feats_out.append(cap_padded_feat) + + ### Process Image + C, F, H, W = image.size() + all_image_size.append((F, H, W)) + 
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + image_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), + start=(cap_ori_len + cap_padding_len + 1, 0, 0), + device=device, + ).flatten(0, 2) + image_padded_pos_ids = torch.cat( + [ + image_ori_pos_ids, + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(image_padding_len, 1), + ], + dim=0, + ) + all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) + # pad mask + image_pad_mask = torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) + all_image_pad_mask.append( + image_pad_mask + if image_padding_len > 0 + else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) + ) + # padded feature + image_padded_feat = torch.cat( + [image, image[-1:].repeat(image_padding_len, 1)], + dim=0, + ) + all_image_out.append(image_padded_feat if image_padding_len > 0 else image) + + return ( + all_image_out, + all_cap_feats_out, + all_image_size, + all_image_pos_ids, + all_cap_pos_ids, + all_image_pad_mask, + all_cap_pad_mask, + ) + + def patchify( + self, + all_image: List[torch.Tensor], + patch_size: int, + f_patch_size: int, + ): + pH = pW = patch_size + pF = f_patch_size + all_image_out = [] + + for i, image in enumerate(all_image): + ### Process Image + C, F, H, W = image.size() + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + # "c f pf h ph w pw -> (f h w) (pf ph pw c)" + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) + + image_ori_len = len(image) + image_padding_len = (-image_ori_len) % SEQ_MULTI_OF + + # padded feature + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + all_image_out.append(image_padded_feat) + + return all_image_out + + def forward( + self, + x: List[torch.Tensor], + t, + cap_feats: List[torch.Tensor], + control_context: List[torch.Tensor], + conditioning_scale: float = 1.0, + patch_size=2, + f_patch_size=1, + ): + if ( + self.t_embedder is None + or self.all_x_embedder is None + or self.cap_embedder is None + or self.rope_embedder is None + or self.noise_refiner is None + or self.context_refiner is None + or self.x_pad_token is None + or self.cap_pad_token is None + ): + raise ValueError( + "Required modules are `None`, use `from_transformer` to share required modules from `transformer`." 
+ ) + + assert patch_size in self.all_patch_size + assert f_patch_size in self.all_f_patch_size + + bsz = len(x) + device = x[0].device + t = t * self.t_scale + t = self.t_embedder(t) + + ( + x, + cap_feats, + x_size, + x_pos_ids, + cap_pos_ids, + x_inner_pad_mask, + cap_inner_pad_mask, + ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) + + # x embed & refine + x_item_seqlens = [len(_) for _ in x] + assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) + x_max_item_seqlen = max(x_item_seqlens) + + x = torch.cat(x, dim=0) + x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) + + # Match t_embedder output dtype to x for layerwise casting compatibility + adaln_input = t.type_as(x) + x[torch.cat(x_inner_pad_mask)] = self.x_pad_token + x = list(x.split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) + + x = pad_sequence(x, batch_first=True, padding_value=0.0) + x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(x_item_seqlens): + x_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.noise_refiner: + x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + else: + for layer in self.noise_refiner: + x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + + # cap embed & refine + cap_item_seqlens = [len(_) for _ in cap_feats] + cap_max_item_seqlen = max(cap_item_seqlens) + + cap_feats = torch.cat(cap_feats, dim=0) + cap_feats = self.cap_embedder(cap_feats) + cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token + cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) + + cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) + cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(cap_item_seqlens): + cap_attn_mask[i, :seq_len] = 1 + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.context_refiner: + cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) + else: + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) + + # unified + unified = [] + unified_freqs_cis = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) + unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) + unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] + assert unified_item_seqlens == [len(_) for _ in unified] + unified_max_item_seqlen = max(unified_item_seqlens) + + unified = pad_sequence(unified, batch_first=True, padding_value=0.0) + unified_freqs_cis = pad_sequence(unified_freqs_cis, 
batch_first=True, padding_value=0.0) + unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) + for i, seq_len in enumerate(unified_item_seqlens): + unified_attn_mask[i, :seq_len] = 1 + + ## ControlNet start + control_context = self.patchify(control_context, patch_size, f_patch_size) + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.control_noise_refiner: + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + + # unified + control_context_unified = [] + for i in range(bsz): + x_len = x_item_seqlens[i] + cap_len = cap_item_seqlens[i] + control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) + control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) + + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + control_context_unified = self._gradient_checkpointing_func( + layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + else: + control_context_unified = layer( + control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input + ) + + hints = torch.unbind(control_context_unified)[:-1] + controlnet_block_samples = { + layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places) + } + return controlnet_block_samples diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 13322aa29ae4..a42f6b2716e1 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -48,4 +48,3 @@ from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel from .transformer_z_image import ZImageTransformer2DModel - from .transformer_z_image_control import ZImageControlTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 5c401b9d202b..17197db3a441 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -536,6 +536,7 @@ def forward( x: List[torch.Tensor], t, cap_feats: List[torch.Tensor], + controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None, patch_size=2, f_patch_size=1, return_dict: bool = True, @@ -635,13 +636,19 @@ def forward( unified_attn_mask[i, :seq_len] = 1 if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = self._gradient_checkpointing_func( layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input ) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] else: - for layer in self.layers: + for layer_idx, layer in enumerate(self.layers): unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) + if controlnet_block_samples is not None: + if layer_idx in controlnet_block_samples: + unified = unified + controlnet_block_samples[layer_idx] unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) unified = list(unified.unbind(dim=0)) diff --git a/src/diffusers/models/transformers/transformer_z_image_control.py b/src/diffusers/models/transformers/transformer_z_image_control.py deleted file mode 100644 index 55116bb75690..000000000000 --- a/src/diffusers/models/transformers/transformer_z_image_control.py +++ /dev/null @@ -1,836 +0,0 @@ -# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence - -from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...models.attention_processor import Attention -from ...models.modeling_utils import ModelMixin -from ...models.normalization import RMSNorm -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention_dispatch import dispatch_attention_fn -from ..modeling_outputs import Transformer2DModelOutput -from .transformer_z_image import ZImageTransformer2DModel - - -ADALN_EMBED_DIM = 256 -SEQ_MULTI_OF = 32 - - -class TimestepEmbedder(nn.Module): - def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): - super().__init__() - if mid_size is None: - mid_size = out_size - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, mid_size, bias=True), - nn.SiLU(), - nn.Linear(mid_size, out_size, bias=True), - ) - - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - with torch.amp.autocast("cuda", enabled=False): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - weight_dtype = self.mlp[0].weight.dtype - compute_dtype = getattr(self.mlp[0], "compute_dtype", None) - if weight_dtype.is_floating_point: - t_freq = t_freq.to(weight_dtype) - elif compute_dtype is not None: - t_freq = t_freq.to(compute_dtype) - t_emb = self.mlp(t_freq) - return t_emb - - -class ZSingleStreamAttnProcessor: - """ - Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the - original Z-ImageAttention module. - """ - - _attention_backend = None - _parallel_config = None - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." 
- ) - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - freqs_cis: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - query = query.unflatten(-1, (attn.heads, -1)) - key = key.unflatten(-1, (attn.heads, -1)) - value = value.unflatten(-1, (attn.heads, -1)) - - # Apply Norms - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE - def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - with torch.amp.autocast("cuda", enabled=False): - x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.unsqueeze(2) - x_out = torch.view_as_real(x * freqs_cis).flatten(3) - return x_out.type_as(x_in) # todo - - if freqs_cis is not None: - query = apply_rotary_emb(query, freqs_cis) - key = apply_rotary_emb(key, freqs_cis) - - # Cast to correct dtype - dtype = query.dtype - query, key = query.to(dtype), key.to(dtype) - - # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] - if attention_mask is not None and attention_mask.ndim == 2: - attention_mask = attention_mask[:, None, None, :] - - # Compute joint attention - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - - # Reshape back - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(dtype) - - output = attn.to_out[0](hidden_states) - if len(attn.to_out) > 1: # dropout - output = attn.to_out[1](output) - - return output - - -class FeedForward(nn.Module): - def __init__(self, dim: int, hidden_dim: int): - super().__init__() - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - - def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 - - def forward(self, x): - return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) - - -@maybe_allow_in_graph -class ZImageTransformerBlock(nn.Module): - def __init__( - self, - layer_id: int, - dim: int, - n_heads: int, - n_kv_heads: int, - norm_eps: float, - qk_norm: bool, - modulation=True, - ): - super().__init__() - self.dim = dim - self.head_dim = dim // n_heads - - # Refactored to use diffusers Attention with custom processor - # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm - self.attention = Attention( - query_dim=dim, - cross_attention_dim=None, - dim_head=dim // n_heads, - heads=n_heads, - qk_norm="rms_norm" if qk_norm else None, - eps=1e-5, - bias=False, - out_bias=False, - processor=ZSingleStreamAttnProcessor(), - ) - - self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) - self.layer_id = layer_id - - self.attention_norm1 = RMSNorm(dim, eps=norm_eps) - self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) - - self.attention_norm2 = RMSNorm(dim, eps=norm_eps) - self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) - - self.modulation = modulation - if modulation: - self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) - - def forward( - self, - x: torch.Tensor, - attn_mask: torch.Tensor, - freqs_cis: torch.Tensor, - adaln_input: 
Optional[torch.Tensor] = None, - ): - if self.modulation: - assert adaln_input is not None - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp - - # Attention block - attn_out = self.attention( - self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis - ) - x = x + gate_msa * self.attention_norm2(attn_out) - - # FFN block - x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) - else: - # Attention block - attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) - x = x + self.attention_norm2(attn_out) - - # FFN block - x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) - - return x - - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), - ) - - def forward(self, x, c): - scale = 1.0 + self.adaLN_modulation(c) - x = self.norm_final(x) * scale.unsqueeze(1) - x = self.linear(x) - return x - - -class RopeEmbedder: - def __init__( - self, - theta: float = 256.0, - axes_dims: List[int] = (16, 56, 56), - axes_lens: List[int] = (64, 128, 128), - ): - self.theta = theta - self.axes_dims = axes_dims - self.axes_lens = axes_lens - assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" - self.freqs_cis = None - - @staticmethod - def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): - with torch.device("cpu"): - freqs_cis = [] - for i, (d, e) in enumerate(zip(dim, end)): - freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) - timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) - freqs = torch.outer(timestep, freqs).float() - freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 - freqs_cis.append(freqs_cis_i) - - return freqs_cis - - def __call__(self, ids: torch.Tensor): - assert ids.ndim == 2 - assert ids.shape[-1] == len(self.axes_dims) - device = ids.device - - if self.freqs_cis is None: - self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - else: - # Ensure freqs_cis are on the same device as ids - if self.freqs_cis[0].device != device: - self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] - - result = [] - for i in range(len(self.axes_dims)): - index = ids[:, i] - result.append(self.freqs_cis[i][index]) - return torch.cat(result, dim=-1) - - -class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): - _supports_gradient_checkpointing = True - _no_split_modules = ["ZImageTransformerBlock", "ZImageControlTransformerBlock"] - _repeated_blocks = ["ZImageTransformerBlock", "ZImageControlTransformerBlock"] - _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers - - @register_to_config - def __init__( - self, - all_patch_size=(2,), - all_f_patch_size=(1,), - in_channels=16, - dim=3840, - n_layers=30, - n_refiner_layers=2, - n_heads=30, - 
n_kv_heads=30, - norm_eps=1e-5, - qk_norm=True, - cap_feat_dim=2560, - rope_theta=256.0, - t_scale=1000.0, - axes_dims=[32, 48, 48], - axes_lens=[1024, 512, 512], - control_layers_places: List[int] = None, - control_in_dim=None, - ) -> None: - super().__init__() - from ...models.controlnets.controlnet_z_image import ZImageControlTransformerBlock - - self.in_channels = in_channels - self.out_channels = in_channels - self.all_patch_size = all_patch_size - self.all_f_patch_size = all_f_patch_size - self.dim = dim - self.n_heads = n_heads - - self.rope_theta = rope_theta - self.t_scale = t_scale - self.gradient_checkpointing = False - - assert len(all_patch_size) == len(all_f_patch_size) - - all_x_embedder = {} - all_final_layer = {} - for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) - all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - - final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) - all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer - - self.all_x_embedder = nn.ModuleDict(all_x_embedder) - self.all_final_layer = nn.ModuleDict(all_final_layer) - self.noise_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.context_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=False, - ) - for layer_id in range(n_refiner_layers) - ] - ) - self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) - self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True)) - - self.x_pad_token = nn.Parameter(torch.empty((1, dim))) - self.cap_pad_token = nn.Parameter(torch.empty((1, dim))) - - self.layers = nn.ModuleList( - [ - ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) - for layer_id in range(n_layers) - ] - ) - head_dim = dim // n_heads - assert head_dim == sum(axes_dims) - self.axes_dims = axes_dims - self.axes_lens = axes_lens - - self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens) - - self.control_layers_places = control_layers_places - self.control_in_dim = control_in_dim - - assert 0 in self.control_layers_places - - # control blocks - self.control_layers = nn.ModuleList( - [ - ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) - for i in self.control_layers_places - ] - ) - - # control patch embeddings - all_x_embedder = {} - for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True) - all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - - self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) - self.control_noise_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - ) - for layer_id in range(n_refiner_layers) - ] - ) - - @classmethod - def from_controlnet( - cls, - transformer: ZImageTransformer2DModel, - controlnet, - load_weights: bool = True, - ): - controlnet.to(device=transformer.device) - - if transformer.config["dim"] != 
controlnet.config["dim"]: - raise ValueError("Incompatible ControlNet, got a different dim.") - - config = dict(transformer.config) - config["_class_name"] = cls.__name__ - - config["control_layers_places"] = controlnet.config["control_layers_places"] - config["control_in_dim"] = controlnet.config["control_in_dim"] - - expected_kwargs, optional_kwargs = cls._get_signature_keys(cls) - config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs}) - config["_class_name"] = cls.__name__ - model = cls.from_config(config) - - if not load_weights: - return model - - for key, all_x_embedder in transformer.all_x_embedder.items(): - model.all_x_embedder[key].load_state_dict(all_x_embedder.state_dict()) - - for key, all_final_layer in transformer.all_final_layer.items(): - model.all_final_layer[key].load_state_dict(all_final_layer.state_dict()) - - for i, noise_refiner in enumerate(transformer.noise_refiner): - model.noise_refiner[i].load_state_dict(noise_refiner.state_dict()) - - for i, context_refiner in enumerate(transformer.context_refiner): - model.context_refiner[i].load_state_dict(context_refiner.state_dict()) - - model.t_embedder.load_state_dict(transformer.t_embedder.state_dict()) - - model.cap_embedder.load_state_dict(transformer.cap_embedder.state_dict()) - - model.x_pad_token = transformer.x_pad_token - - model.cap_pad_token = transformer.cap_pad_token - - for i, layer in enumerate(transformer.layers): - model.layers[i].load_state_dict(layer.state_dict()) - - for i, control_layer in enumerate(controlnet.control_layers): - model.control_layers[i].load_state_dict(control_layer.state_dict()) - - for key, control_all_x_embedder in controlnet.control_all_x_embedder.items(): - model.control_all_x_embedder[key].load_state_dict(control_all_x_embedder.state_dict()) - - for i, control_noise_refiner in enumerate(controlnet.control_noise_refiner): - model.control_noise_refiner[i].load_state_dict(control_noise_refiner.state_dict()) - - model.to(transformer.dtype) - - return model - - def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]: - pH = pW = patch_size - pF = f_patch_size - bsz = len(x) - assert len(size) == bsz - for i in range(bsz): - F, H, W = size[i] - ori_len = (F // pF) * (H // pH) * (W // pW) - # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)" - x[i] = ( - x[i][:ori_len] - .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels) - .permute(6, 0, 3, 1, 4, 2, 5) - .reshape(self.out_channels, F, H, W) - ) - return x - - @staticmethod - def create_coordinate_grid(size, start=None, device=None): - if start is None: - start = (0 for _ in size) - - axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] - grids = torch.meshgrid(axes, indexing="ij") - return torch.stack(grids, dim=-1) - - def patchify_and_embed( - self, - all_image: List[torch.Tensor], - all_cap_feats: List[torch.Tensor], - patch_size: int, - f_patch_size: int, - ): - pH = pW = patch_size - pF = f_patch_size - device = all_image[0].device - - all_image_out = [] - all_image_size = [] - all_image_pos_ids = [] - all_image_pad_mask = [] - all_cap_pos_ids = [] - all_cap_pad_mask = [] - all_cap_feats_out = [] - - for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): - ### Process Caption - cap_ori_len = len(cap_feat) - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF - # padded position ids - cap_padded_pos_ids = self.create_coordinate_grid( - size=(cap_ori_len + 
cap_padding_len, 1, 1), - start=(1, 0, 0), - device=device, - ).flatten(0, 2) - all_cap_pos_ids.append(cap_padded_pos_ids) - # pad mask - cap_pad_mask = torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - all_cap_pad_mask.append( - cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) - ) - - # padded feature - cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) - all_cap_feats_out.append(cap_padded_feat) - - ### Process Image - C, F, H, W = image.size() - all_image_size.append((F, H, W)) - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - image_ori_pos_ids = self.create_coordinate_grid( - size=(F_tokens, H_tokens, W_tokens), - start=(cap_ori_len + cap_padding_len + 1, 0, 0), - device=device, - ).flatten(0, 2) - image_padded_pos_ids = torch.cat( - [ - image_ori_pos_ids, - self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) - .flatten(0, 2) - .repeat(image_padding_len, 1), - ], - dim=0, - ) - all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) - # pad mask - image_pad_mask = torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - all_image_pad_mask.append( - image_pad_mask - if image_padding_len > 0 - else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) - ) - # padded feature - image_padded_feat = torch.cat( - [image, image[-1:].repeat(image_padding_len, 1)], - dim=0, - ) - all_image_out.append(image_padded_feat if image_padding_len > 0 else image) - - return ( - all_image_out, - all_cap_feats_out, - all_image_size, - all_image_pos_ids, - all_cap_pos_ids, - all_image_pad_mask, - all_cap_pad_mask, - ) - - def patchify( - self, - all_image: List[torch.Tensor], - patch_size: int, - f_patch_size: int, - ): - pH = pW = patch_size - pF = f_patch_size - all_image_out = [] - - for i, image in enumerate(all_image): - ### Process Image - C, F, H, W = image.size() - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW - - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) - - image_ori_len = len(image) - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF - - # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) - all_image_out.append(image_padded_feat) - - return all_image_out - - def forward( - self, - x: List[torch.Tensor], - t, - cap_feats: List[torch.Tensor], - patch_size=2, - f_patch_size=1, - control_context: Optional[List[torch.Tensor]] = None, - conditioning_scale: float = 1.0, - return_dict: bool = True, - ): - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size - - bsz = len(x) - device = x[0].device - t = t * self.t_scale - t = self.t_embedder(t) - - ( - x, - cap_feats, - x_size, - x_pos_ids, - cap_pos_ids, - x_inner_pad_mask, - 
cap_inner_pad_mask, - ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) - - # x embed & refine - x_item_seqlens = [len(_) for _ in x] - assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) - x_max_item_seqlen = max(x_item_seqlens) - - x = torch.cat(x, dim=0) - x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) - - # Match t_embedder output dtype to x for layerwise casting compatibility - adaln_input = t.type_as(x) - x[torch.cat(x_inner_pad_mask)] = self.x_pad_token - x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) - - x = pad_sequence(x, batch_first=True, padding_value=0.0) - x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - x_freqs_cis = x_freqs_cis[:, : x.shape[1]] - - x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(x_item_seqlens): - x_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: - x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) - else: - for layer in self.noise_refiner: - x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) - - # cap embed & refine - cap_item_seqlens = [len(_) for _ in cap_feats] - cap_max_item_seqlen = max(cap_item_seqlens) - - cap_feats = torch.cat(cap_feats, dim=0) - cap_feats = self.cap_embedder(cap_feats) - cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token - cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list( - self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) - ) - - cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) - cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) - # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors - cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] - - cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(cap_item_seqlens): - cap_attn_mask[i, :seq_len] = 1 - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.context_refiner: - cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis) - else: - for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis) - - # unified - unified = [] - unified_freqs_cis = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]])) - unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]])) - unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)] - assert unified_item_seqlens == [len(_) for _ in unified] - unified_max_item_seqlen = max(unified_item_seqlens) - - unified = pad_sequence(unified, batch_first=True, padding_value=0.0) - unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0) - unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device) - for i, seq_len in enumerate(unified_item_seqlens): - unified_attn_mask[i, :seq_len] = 1 - - ## ControlNet start - - 
controlnet_block_samples = None - if control_context is not None: - control_context = self.patchify(control_context, patch_size, f_patch_size) - control_context = torch.cat(control_context, dim=0) - control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) - - control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token - control_context = list(control_context.split(x_item_seqlens, dim=0)) - - control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.control_noise_refiner: - control_context = self._gradient_checkpointing_func( - layer, control_context, x_attn_mask, x_freqs_cis, adaln_input - ) - else: - for layer in self.control_noise_refiner: - control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) - - # unified - control_context_unified = [] - for i in range(bsz): - x_len = x_item_seqlens[i] - cap_len = cap_item_seqlens[i] - control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]])) - control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0) - - for layer in self.control_layers: - if torch.is_grad_enabled() and self.gradient_checkpointing: - control_context_unified = self._gradient_checkpointing_func( - layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input - ) - else: - control_context_unified = layer( - control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input - ) - - hints = torch.unbind(control_context_unified)[:-1] - controlnet_block_samples = { - layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places) - } - - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer_idx, layer in enumerate(self.layers): - unified = self._gradient_checkpointing_func( - layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input - ) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] - else: - for layer_idx, layer in enumerate(self.layers): - unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input) - if controlnet_block_samples is not None: - if layer_idx in controlnet_block_samples: - unified = unified + controlnet_block_samples[layer_idx] - - unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input) - unified = list(unified.unbind(dim=0)) - x = self.unpatchify(unified, x_size, patch_size, f_patch_size) - - if not return_dict: - return (x,) - - return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 2faea94fe134..904259f805ba 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -22,7 +22,7 @@ from ...loaders import FromSingleFileMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnets import ZImageControlNetModel -from ...models.transformers import ZImageControlTransformer2DModel, ZImageTransformer2DModel +from ...models.transformers import ZImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring @@ 
-167,12 +167,11 @@ def __init__( vae: AutoencoderKL, text_encoder: PreTrainedModel, tokenizer: AutoTokenizer, - transformer: Union[ZImageControlTransformer2DModel, ZImageTransformer2DModel], + transformer: ZImageTransformer2DModel, controlnet: ZImageControlNetModel, ): super().__init__() - if isinstance(transformer, ZImageTransformer2DModel): - transformer = ZImageControlTransformer2DModel.from_controlnet(transformer, controlnet) + controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer) self.register_modules( vae=vae, @@ -601,12 +600,19 @@ def __call__( latent_model_input = latent_model_input.unsqueeze(2) latent_model_input_list = list(latent_model_input.unbind(dim=0)) - model_out_list = self.transformer( + controlnet_block_samples = self.controlnet( latent_model_input_list, timestep_model_input, prompt_embeds_model_input, - control_context=control_image, + control_image, conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, )[0] if apply_cfg: From a961402959395d23a666a49b067bbdb12de48727 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 11 Dec 2025 03:36:40 +0000 Subject: [PATCH 22/38] from_single_file --- src/diffusers/loaders/single_file_model.py | 8 ++++++++ src/diffusers/loaders/single_file_utils.py | 19 +++++++++++++++++++ .../models/controlnets/controlnet_z_image.py | 3 ++- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 803fdfc2d952..57f1594e5e87 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -53,6 +53,7 @@ create_controlnet_diffusers_config_from_ldm, create_unet_diffusers_config_from_ldm, create_vae_diffusers_config_from_ldm, + create_z_image_controlnet_config, fetch_diffusers_config, fetch_original_config, load_single_file_checkpoint, @@ -172,6 +173,10 @@ "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "ZImageControlNetModel": { + "checkpoint_mapping_fn": lambda x: x, + "config_create_fn": create_z_image_controlnet_config, + }, } @@ -369,6 +374,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = diffusers_model_config = config_mapping_fn( original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs ) + elif "config_create_fn" in mapping_functions: + config_create_fn = mapping_functions["config_create_fn"] + diffusers_model_config = config_create_fn() else: if config is not None: if isinstance(config, str): diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index b866a5a21ae3..6f24b0f848ec 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -121,6 +121,7 @@ "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight", "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"], "z-image-turbo": "cap_embedder.0.weight", + "z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight", "sana": [ "blocks.0.cross_attn.q_linear.weight", "blocks.0.cross_attn.q_linear.bias", @@ -779,6 +780,9 @@ def infer_diffusers_model_type(checkpoint): else: raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.") + elif 
CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint: + model_type = "z-image-turbo-controlnet" + else: model_type = "v1" @@ -3885,3 +3889,18 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) handler_fn_inplace(key, converted_state_dict) return converted_state_dict + + +def create_z_image_controlnet_config(): + return { + "all_f_patch_size": [1], + "all_patch_size": [2], + "control_in_dim": 16, + "control_layers_places": [0, 5, 10, 15, 20, 25], + "dim": 3840, + "n_heads": 30, + "n_kv_heads": 30, + "n_refiner_layers": 2, + "norm_eps": 1e-05, + "qk_norm": True, + } diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index ea8eaf0a0788..2e518fd46f0a 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -20,6 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin +from ...loaders.single_file_model import FromOriginalModelMixin from ...models.attention_processor import Attention from ...models.normalization import RMSNorm from ...utils.torch_utils import maybe_allow_in_graph @@ -132,7 +133,7 @@ def forward( return c -class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True @register_to_config From 6e1c2189d005ec036fd9ffe06c13878b4d18edc6 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 11 Dec 2025 03:42:05 +0000 Subject: [PATCH 23/38] Remove convert script --- ...convert_z_image_controlnet_to_diffusers.py | 70 ------------------- 1 file changed, 70 deletions(-) delete mode 100644 scripts/convert_z_image_controlnet_to_diffusers.py diff --git a/scripts/convert_z_image_controlnet_to_diffusers.py b/scripts/convert_z_image_controlnet_to_diffusers.py deleted file mode 100644 index e5d5f34e36e8..000000000000 --- a/scripts/convert_z_image_controlnet_to_diffusers.py +++ /dev/null @@ -1,70 +0,0 @@ -import argparse -from contextlib import nullcontext - -import safetensors.torch -import torch -from accelerate import init_empty_weights -from huggingface_hub import hf_hub_download - -from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel -from diffusers.utils.import_utils import is_accelerate_available - - -""" -python scripts/convert_z_image_controlnet_to_diffusers.py \ ---original_controlnet_repo_id "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union" \ ---filename "Z-Image-Turbo-Fun-Controlnet-Union.safetensors" ---output_path "z-image-controlnet-hf/" -""" - - -CTX = init_empty_weights if is_accelerate_available else nullcontext - -parser = argparse.ArgumentParser() -parser.add_argument("--original_controlnet_repo_id", default=None, type=str) -parser.add_argument("--filename", default="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", type=str) -parser.add_argument("--checkpoint_path", default=None, type=str) -parser.add_argument("--output_path", type=str) - -args = parser.parse_args() - - -def load_original_checkpoint(args): - if args.original_controlnet_repo_id is not None: - ckpt_path = hf_hub_download(repo_id=args.original_controlnet_repo_id, filename=args.filename) - elif args.checkpoint_path is not None: - ckpt_path = args.checkpoint_path - else: - raise ValueError(" please provide either `original_controlnet_repo_id` or a local `checkpoint_path`") - - original_state_dict = 
safetensors.torch.load_file(ckpt_path) - return original_state_dict - - -def convert_z_image_controlnet_checkpoint_to_diffusers(original_state_dict): - converted_state_dict = {} - - converted_state_dict.update(original_state_dict) - - return converted_state_dict - - -def main(args): - original_ckpt = load_original_checkpoint(args) - - control_in_dim = 16 - control_layers_places = [0, 5, 10, 15, 20, 25] - - converted_controlnet_state_dict = convert_z_image_controlnet_checkpoint_to_diffusers(original_ckpt) - - controlnet = ZImageControlNetModel( - control_layers_places=control_layers_places, - control_in_dim=control_in_dim, - ).to(torch.bfloat16) - controlnet.load_state_dict(converted_controlnet_state_dict) - print("Saving Z-Image ControlNet in Diffusers format") - controlnet.save_pretrained(args.output_path) - - -if __name__ == "__main__": - main(args) From 8e7743a44c410508ad7187819ab5049abed9d7b8 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 11 Dec 2025 21:09:48 +0000 Subject: [PATCH 24/38] Copied from --- .../models/controlnets/controlnet_z_image.py | 276 +++++++++++++++++- 1 file changed, 267 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 2e518fd46f0a..5b0c0d363f4a 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import List, Optional import torch import torch.nn as nn +import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence from ...configuration_utils import ConfigMixin, register_to_config @@ -24,17 +26,271 @@ from ...models.attention_processor import Attention from ...models.normalization import RMSNorm from ...utils.torch_utils import maybe_allow_in_graph +from ..attention_dispatch import dispatch_attention_fn from ..controlnets.controlnet import zero_module from ..modeling_utils import ModelMixin -from ..transformers.transformer_z_image import ( - ADALN_EMBED_DIM, - SEQ_MULTI_OF, - FeedForward, - RopeEmbedder, - TimestepEmbedder, - ZImageTransformerBlock, - ZSingleStreamAttnProcessor, -) + + +ADALN_EMBED_DIM = 256 +SEQ_MULTI_OF = 32 + + +# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder +class TimestepEmbedder(nn.Module): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + super().__init__() + if mid_size is None: + mid_size = out_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, mid_size, bias=True), + nn.SiLU(), + nn.Linear(mid_size, out_size, bias=True), + ) + + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + with torch.amp.autocast("cuda", enabled=False): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) + if weight_dtype.is_floating_point: + t_freq = 
t_freq.to(weight_dtype) + elif compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +# Copied from diffusers.models.transformers.transformer_z_image.ZSingleStreamAttnProcessor +class ZSingleStreamAttnProcessor: + """ + Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the + original Z-ImageAttention module. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + # Apply Norms + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE + def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", enabled=False): + x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.unsqueeze(2) + x_out = torch.view_as_real(x * freqs_cis).flatten(3) + return x_out.type_as(x_in) # todo + + if freqs_cis is not None: + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + + # Cast to correct dtype + dtype = query.dtype + query, key = query.to(dtype), key.to(dtype) + + # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len] + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = attention_mask[:, None, None, :] + + # Compute joint attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape back + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(dtype) + + output = attn.to_out[0](hidden_states) + if len(attn.to_out) > 1: # dropout + output = attn.to_out[1](output) + + return output + + +# Copied from diffusers.models.transformers.transformer_z_image.FeedForward +class FeedForward(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def _forward_silu_gating(self, x1, x3): + return F.silu(x1) * x3 + + def forward(self, x): + return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + + +@maybe_allow_in_graph +# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock +class ZImageTransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + modulation=True, + ): + super().__init__() + self.dim = dim + self.head_dim = dim // n_heads + + # Refactored to use diffusers Attention with custom processor + # Original Z-Image 
params: dim, n_heads, n_kv_heads, qk_norm + self.attention = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // n_heads, + heads=n_heads, + qk_norm="rms_norm" if qk_norm else None, + eps=1e-5, + bias=False, + out_bias=False, + processor=ZSingleStreamAttnProcessor(), + ) + + self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.layer_id = layer_id + + self.attention_norm1 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.attention_norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + self.modulation = modulation + if modulation: + self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True)) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + ): + if self.modulation: + assert adaln_input is not None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp + + # Attention block + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + # FFN block + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + else: + # Attention block + attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis) + x = x + self.attention_norm2(attn_out) + + # FFN block + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +# Copied from diffusers.models.transformers.transformer_z_image.RopeEmbedder +class RopeEmbedder: + def __init__( + self, + theta: float = 256.0, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (64, 128, 128), + ): + self.theta = theta + self.axes_dims = axes_dims + self.axes_lens = axes_lens + assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" + self.freqs_cis = None + + @staticmethod + def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0): + with torch.device("cpu"): + freqs_cis = [] + for i, (d, e) in enumerate(zip(dim, end)): + freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d)) + timestep = torch.arange(e, device=freqs.device, dtype=torch.float64) + freqs = torch.outer(timestep, freqs).float() + freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64 + freqs_cis.append(freqs_cis_i) + + return freqs_cis + + def __call__(self, ids: torch.Tensor): + assert ids.ndim == 2 + assert ids.shape[-1] == len(self.axes_dims) + device = ids.device + + if self.freqs_cis is None: + self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + else: + # Ensure freqs_cis are on the same device as ids + if self.freqs_cis[0].device != device: + self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis] + + result = [] + for i in range(len(self.axes_dims)): + index = ids[:, i] + result.append(self.freqs_cis[i][index]) + return torch.cat(result, dim=-1) @maybe_allow_in_graph @@ -208,6 +464,7 @@ def from_transformer(cls, controlnet, transformer): return controlnet @staticmethod + # Copied from 
diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.create_coordinate_grid def create_coordinate_grid(size, start=None, device=None): if start is None: start = (0 for _ in size) @@ -216,6 +473,7 @@ def create_coordinate_grid(size, start=None, device=None): grids = torch.meshgrid(axes, indexing="ij") return torch.stack(grids, dim=-1) + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed def patchify_and_embed( self, all_image: List[torch.Tensor], From c13517029100fa285e245fd5f553c430f77f76fc Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 11 Dec 2025 21:42:20 +0000 Subject: [PATCH 25/38] Example --- .../z_image/pipeline_z_image_controlnet.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 904259f805ba..33dfc1996c1f 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -39,11 +39,17 @@ >>> from diffusers import ZImageControlNetPipeline >>> from diffusers import ZImageControlNetModel >>> from diffusers.utils import load_image + >>> from huggingface_hub import hf_hub_download - >>> controlnet_model = "..." - >>> controlnet = ZImageControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16) + >>> controlnet = ZImageControlNetModel.from_single_file( + ... hf_hub_download( + ... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union", + ... filename="Z-Image-Turbo-Fun-Controlnet-Union.safetensors", + ... ), + ... torch_dtype=torch.bfloat16, + ... ) - >>> pipe = ZImageControlNetPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16) + >>> pipe = ZImageControlNetPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. @@ -52,16 +58,17 @@ >>> # (2) Use flash attention 3 >>> # pipe.transformer.set_attention_backend("_flash_3") - >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg") - >>> prompt = "A girl in city, 25 years old, cool, futuristic" + >>> control_image = load_image("https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true") + >>> prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。" >>> image = pipe( ... prompt, ... control_image=control_image, - ... height=1024, - ... width=1024, + ... controlnet_conditioning_scale=0.75, + ... height=1728, + ... width=992, ... num_inference_steps=9, ... guidance_scale=0.0, - ... generator=torch.Generator("cuda").manual_seed(42), + ... generator=torch.Generator("cuda").manual_seed(43), ... 
).images[0] >>> image.save("zimage.png") ``` From a737b3cc1e139fb234a960a130fb1e9cc6dd0d92 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 12 Dec 2025 18:33:01 +0000 Subject: [PATCH 26/38] doc-builder style --- .../pipelines/z_image/pipeline_z_image_controlnet.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 33dfc1996c1f..51b2c9bdd445 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -49,7 +49,9 @@ ... torch_dtype=torch.bfloat16, ... ) - >>> pipe = ZImageControlNetPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16) + >>> pipe = ZImageControlNetPipeline.from_pretrained( + ... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) >>> pipe.to("cuda") >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. @@ -58,7 +60,9 @@ >>> # (2) Use flash attention 3 >>> # pipe.transformer.set_attention_backend("_flash_3") - >>> control_image = load_image("https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true") + >>> control_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/asset/pose.jpg?download=true" + ... ) >>> prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。" >>> image = pipe( ... prompt, From 7bc847ac389166c79efaaf64378e4b1d0590248b Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 12 Dec 2025 18:33:13 +0000 Subject: [PATCH 27/38] check_dummies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 8628893200fe..d8ff309c186c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1762,6 +1762,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ZImageControlNetModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ZImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ff65372f3c89..e0388e6c143d 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -3812,6 +3812,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ZImageControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + 
@classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ZImageImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From ffde03532240b0808904068c35199c4ac5185cb1 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 12 Dec 2025 18:33:23 +0000 Subject: [PATCH 28/38] custom_init_isort --- src/diffusers/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f0af3513255a..a66b81ab408e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -278,8 +278,8 @@ "WanAnimateTransformer3DModel", "WanTransformer3DModel", "WanVACETransformer3DModel", - "ZImageTransformer2DModel", "ZImageControlNetModel", + "ZImageTransformer2DModel", "attention_backend", ] ) @@ -667,9 +667,9 @@ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", + "ZImageControlNetPipeline", "ZImageImg2ImgPipeline", "ZImagePipeline", - "ZImageControlNetPipeline", ] ) From 04388f4698b303785b26bec6179a55aea652a388 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 10:56:49 +0000 Subject: [PATCH 29/38] init v2 --- src/diffusers/loaders/single_file_model.py | 3 +- src/diffusers/loaders/single_file_utils.py | 27 ++++- .../models/controlnets/controlnet_z_image.py | 101 +++++++++++++----- 3 files changed, 98 insertions(+), 33 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 57f1594e5e87..63b31b34529f 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -376,7 +376,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ) elif "config_create_fn" in mapping_functions: config_create_fn = mapping_functions["config_create_fn"] - diffusers_model_config = config_create_fn() + config_create_kwargs = _get_mapping_function_kwargs(config_create_fn, **kwargs) + diffusers_model_config = config_create_fn(checkpoint=checkpoint, **config_create_kwargs) else: if config is not None: if isinstance(config, str): diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 6f24b0f848ec..20ab5d827255 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -3891,10 +3891,8 @@ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) return converted_state_dict -def create_z_image_controlnet_config(): - return { - "all_f_patch_size": [1], - "all_patch_size": [2], +def create_z_image_controlnet_config(checkpoint, **kwargs): + v1_config = { "control_in_dim": 16, "control_layers_places": [0, 5, 10, 15, 20, 25], "dim": 3840, @@ -3903,4 +3901,25 @@ def create_z_image_controlnet_config(): "n_refiner_layers": 2, "norm_eps": 1e-05, "qk_norm": True, + "all_f_patch_size": [1], + "all_patch_size": [2], + } + v2_config = { + "control_in_dim": 33, + "control_layers_places": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28], + "control_refiner_layers_places": [0, 1], + "add_control_noise_refiner": True, + "dim": 3840, + "n_heads": 30, + "n_kv_heads": 30, + "n_refiner_layers": 2, + "norm_eps": 1e-05, + "qk_norm": True, + "all_f_patch_size": [1], + "all_patch_size": [2], } + control_x_embedder_weight_shape = checkpoint["control_all_x_embedder.2-1.weight"].shape[1] + if control_x_embedder_weight_shape == 64: + return v1_config + elif control_x_embedder_weight_shape == 132: + return 
v2_config diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 5b0c0d363f4a..f7c7c4033ae4 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -395,6 +395,10 @@ class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi @register_to_config def __init__( self, + control_layers_places: List[int] = None, + control_refiner_layers_places: List[int] = None, + control_in_dim=None, + add_control_noise_refiner=False, all_patch_size=(2,), all_f_patch_size=(1,), dim=3840, @@ -403,12 +407,12 @@ def __init__( n_kv_heads=30, norm_eps=1e-5, qk_norm=True, - control_layers_places: List[int] = None, - control_in_dim=None, ): super().__init__() self.control_layers_places = control_layers_places self.control_in_dim = control_in_dim + self.control_refiner_layers_places = control_refiner_layers_places + self.add_control_noise_refiner = add_control_noise_refiner assert 0 in self.control_layers_places @@ -427,20 +431,37 @@ def __init__( all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) - self.control_noise_refiner = nn.ModuleList( - [ - ZImageTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - ) - for layer_id in range(n_refiner_layers) - ] - ) + if self.add_control_noise_refiner: + self.control_noise_refiner = nn.ModuleList( + [ + ZImageControlTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + block_id=layer_id, + ) + for layer_id in range(n_refiner_layers) + ] + ) + else: + self.control_noise_refiner = nn.ModuleList( + [ + ZImageTransformerBlock( + 1000 + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + modulation=True, + ) + for layer_id in range(n_refiner_layers) + ] + ) self.t_embedder: Optional[TimestepEmbedder] = None self.all_x_embedder: Optional[nn.ModuleDict] = None @@ -647,11 +668,20 @@ def forward( cap_inner_pad_mask, ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size) - # x embed & refine x_item_seqlens = [len(_) for _ in x] assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens) x_max_item_seqlen = max(x_item_seqlens) + control_context = self.patchify(control_context, patch_size, f_patch_size) + control_context = torch.cat(control_context, dim=0) + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) + + control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token + control_context = list(control_context.split(x_item_seqlens, dim=0)) + + control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) + + # x embed & refine x = torch.cat(x, dim=0) x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x) @@ -670,12 +700,36 @@ def forward( for i, seq_len in enumerate(x_item_seqlens): x_attn_mask[i, :seq_len] = 1 + if self.add_control_noise_refiner: + for layer in self.control_layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + control_context = self._gradient_checkpointing_func( + layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input + ) + else: + control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input) + + hints = torch.unbind(control_context)[:-1] + control_context = torch.unbind(control_context)[-1] + noise_refiner_block_samples = { + layer_idx: hints[idx] * conditioning_scale 
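+                    # one hint per control-refiner place: each is scaled by `conditioning_scale`
+                    # and added to the output of the matching `noise_refiner` block below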
+ for idx, layer_idx in enumerate(self.control_refiner_layers_places) + } + else: + noise_refiner_block_samples = None + if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.noise_refiner: + for layer_idx, layer in enumerate(self.noise_refiner): x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input) + if noise_refiner_block_samples is not None: + if layer_idx in noise_refiner_block_samples: + x = x + noise_refiner_block_samples[layer_idx] else: - for layer in self.noise_refiner: + for layer_idx, layer in enumerate(self.noise_refiner): x = layer(x, x_attn_mask, x_freqs_cis, adaln_input) + if noise_refiner_block_samples is not None: + if layer_idx in noise_refiner_block_samples: + x = x + noise_refiner_block_samples[layer_idx] # cap embed & refine cap_item_seqlens = [len(_) for _ in cap_feats] @@ -724,15 +778,6 @@ def forward( unified_attn_mask[i, :seq_len] = 1 ## ControlNet start - control_context = self.patchify(control_context, patch_size, f_patch_size) - control_context = torch.cat(control_context, dim=0) - control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context) - - control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token - control_context = list(control_context.split(x_item_seqlens, dim=0)) - - control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0) - if torch.is_grad_enabled() and self.gradient_checkpointing: for layer in self.control_noise_refiner: control_context = self._gradient_checkpointing_func( From 62ee1c1d61a8beed7172d9746247b37b125b4c64 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 12:03:46 +0000 Subject: [PATCH 30/38] handle 2.0 t2i pipeline control_image dimension --- .../z_image/pipeline_z_image_controlnet.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py index 51b2c9bdd445..62e26312ff66 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py @@ -533,6 +533,20 @@ def __call__( control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor control_image = control_image.unsqueeze(2) + if num_channels_latents != self.controlnet.config.control_in_dim: + # For model version 2.0 + control_image = torch.cat( + [ + control_image, + torch.zeros( + control_image.shape[0], + self.controlnet.config.control_in_dim - num_channels_latents, + *control_image.shape[2:], + ).to(device=control_image.device, dtype=control_image.dtype), + ], + dim=1, + ) + latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, From faf5a24c50c691a865bec44d9b7d84d26f4a8603 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 12:28:43 +0000 Subject: [PATCH 31/38] ZImageControlNetInpaintPipeline --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 14 +- src/diffusers/pipelines/z_image/__init__.py | 2 + .../pipeline_z_image_controlnet_inpaint.py | 736 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + 5 files changed, 767 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a66b81ab408e..7a94c75e8a1e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -668,6 +668,7 @@ 
"WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", "ZImageControlNetPipeline", + "ZImageControlNetInpaintPipeline", "ZImageImg2ImgPipeline", "ZImagePipeline", ] @@ -1372,6 +1373,7 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, + ZImageControlNetInpaintPipeline, ZImageControlNetPipeline, ZImageImg2ImgPipeline, ZImagePipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 41489c183f79..6b5ddbcf28b4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -404,7 +404,12 @@ "Kandinsky5T2IPipeline", "Kandinsky5I2IPipeline", ] - _import_structure["z_image"] = ["ZImageImg2ImgPipeline", "ZImagePipeline", "ZImageControlNetPipeline"] + _import_structure["z_image"] = [ + "ZImageImg2ImgPipeline", + "ZImagePipeline", + "ZImageControlNetPipeline", + "ZImageControlNetInpaintPipeline", + ] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -841,7 +846,12 @@ WuerstchenDecoderPipeline, WuerstchenPriorPipeline, ) - from .z_image import ZImageControlNetPipeline, ZImageImg2ImgPipeline, ZImagePipeline + from .z_image import ( + ZImageControlNetInpaintPipeline, + ZImageControlNetPipeline, + ZImageImg2ImgPipeline, + ZImagePipeline, + ) try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py index 4bb1900edb14..7b3cfbceea2c 100644 --- a/src/diffusers/pipelines/z_image/__init__.py +++ b/src/diffusers/pipelines/z_image/__init__.py @@ -24,6 +24,7 @@ _import_structure["pipeline_output"] = ["ZImagePipelineOutput"] _import_structure["pipeline_z_image"] = ["ZImagePipeline"] _import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"] + _import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"] _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"] @@ -38,6 +39,7 @@ from .pipeline_output import ZImagePipelineOutput from .pipeline_z_image import ZImagePipeline from .pipeline_z_image_controlnet import ZImageControlNetPipeline + from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline from .pipeline_z_image_img2img import ZImageImg2ImgPipeline else: diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py new file mode 100644 index 000000000000..d18172083cef --- /dev/null +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -0,0 +1,736 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, PreTrainedModel + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKL +from ...models.controlnets import ZImageControlNetModel +from ...models.transformers import ZImageTransformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import ZImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageControlNetInpaintPipeline + >>> from diffusers import ZImageControlNetModel + >>> from diffusers.utils import load_image + >>> from huggingface_hub import hf_hub_download + + >>> controlnet = ZImageControlNetModel.from_single_file( + ... hf_hub_download( + ... "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0", + ... filename="Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors", + ... ), + ... torch_dtype=torch.bfloat16, + ... ) + + >>> pipe = ZImageControlNetInpaintPipeline.from_pretrained( + ... "Tongyi-MAI/Z-Image-Turbo", controlnet=controlnet, torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/inpaint.jpg?download=true" + ... ) + >>> mask_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/mask.jpg?download=true" + ... ) + >>> control_image = load_image( + ... "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0/resolve/main/asset/pose.jpg?download=true" + ... ) + >>> prompt = "一位年轻女子站在阳光明媚的海岸线上,画面为全身竖构图,身体微微侧向右侧,左手自然下垂,右臂弯曲扶在腰间,她的手指清晰可见,站姿放松而略带羞涩。她身穿轻盈的白色连衣裙,裙摆在海风中轻轻飘动,布料半透、质感柔软。女子拥有一头鲜艳的及腰紫色长发,被海风吹起,在身侧轻盈飞舞,发间系着一个精致的黑色蝴蝶结,与发色形成对比。她面容清秀,眉目精致,肤色白皙细腻,表情温柔略显羞涩,微微低头,眼神静静望向远处的海平线,流露出甜美的青春气息与若有所思的神情。背景是辽阔无垠的海洋与蔚蓝天空,阳光从侧前方洒下,海面波光粼粼,泛着温暖的金色光晕,天空清澈明亮,云朵稀薄,整体色调清新唯美。" + >>> image = pipe( + ... prompt, + ... image=image, + ... mask_image=mask_image, + ... control_image=control_image, + ... controlnet_conditioning_scale=0.75, + ... height=1728, + ... width=992, + ... num_inference_steps=25, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(43), + ... 
).images[0] + >>> image.save("zimage-inpaint.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + controlnet: ZImageControlNetModel, + ): + super().__init__() + if self.transformer.in_channels == self.controlnet.config.control_in_dim: + raise ValueError( + "ZImageControlNetInpaintPipeline is not compatible with `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union`, use `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0`." + ) + controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + controlnet=controlnet, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + 
padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.75, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + 
): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. 
Prepare latent variables + num_channels_latents = self.transformer.in_channels + + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = control_image.shape[-2:] + control_image = retrieve_latents(self.vae.encode(control_image), generator=generator, sample_mode="argmax") + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + control_image = control_image.unsqueeze(2) + + mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to( + device=control_image.device, dtype=control_image.dtype + ) + + init_image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + ) + height, width = init_image.shape[-2:] + init_image = retrieve_latents(self.vae.encode(init_image), generator=generator, sample_mode="argmax") + init_image = (init_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + init_image = init_image.unsqueeze(2) + + mask_condition = F.interpolate(1 - mask_condition[:, :1], size=init_image.size()[-2:], mode="nearest").to( + device=control_image.device, dtype=control_image.dtype + ) + mask_condition = mask_condition.unsqueeze(2) + + control_image = torch.cat([control_image, mask_condition, init_image], dim=1) + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + controlnet_block_samples = self.controlnet( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + control_image, + conditioning_scale=controlnet_conditioning_scale, + ) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + controlnet_block_samples=controlnet_block_samples, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, 
return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e0388e6c143d..e8ce55e4d5c0 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -3812,6 +3812,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ZImageControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ZImageControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 6126f02da270829963e2b5c8a5190d09e59bf1ab Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 12:48:35 +0000 Subject: [PATCH 32/38] t_scale --- src/diffusers/models/controlnets/controlnet_z_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index f7c7c4033ae4..ccb09c8620ed 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -474,6 +474,7 @@ def __init__( @classmethod def from_transformer(cls, controlnet, transformer): + controlnet.t_scale = transformer.t_scale controlnet.t_embedder = transformer.t_embedder controlnet.all_x_embedder = transformer.all_x_embedder controlnet.cap_embedder = transformer.cap_embedder From c3def6b5f2c91d6d571a45c9054b5147e3b797d9 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 12:49:32 +0000 Subject: [PATCH 33/38] config.all_patch_size --- src/diffusers/models/controlnets/controlnet_z_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index ccb09c8620ed..4d984802a3d4 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -651,8 +651,8 @@ def forward( "Required modules are `None`, use `from_transformer` to share required modules from `transformer`." 
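+                # `from_transformer` (defined above) copies the shared modules, i.e. `t_scale`, `t_embedder`,
+                # `all_x_embedder`, `cap_embedder`, etc., from the base transformer onto the ControlNet.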
) - assert patch_size in self.all_patch_size - assert f_patch_size in self.all_f_patch_size + assert patch_size in self.config.all_patch_size + assert f_patch_size in self.config.all_f_patch_size bsz = len(x) device = x[0].device From f80ed52fcc3372a1590a0e60011d0a13a7c59524 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 13:03:00 +0000 Subject: [PATCH 34/38] not self.add_control_noise_refiner --- .../models/controlnets/controlnet_z_image.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index 4d984802a3d4..bc77d2a60bf6 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -779,14 +779,15 @@ def forward( unified_attn_mask[i, :seq_len] = 1 ## ControlNet start - if torch.is_grad_enabled() and self.gradient_checkpointing: - for layer in self.control_noise_refiner: - control_context = self._gradient_checkpointing_func( - layer, control_context, x_attn_mask, x_freqs_cis, adaln_input - ) - else: - for layer in self.control_noise_refiner: - control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) + if not self.add_control_noise_refiner: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for layer in self.control_noise_refiner: + control_context = self._gradient_checkpointing_func( + layer, control_context, x_attn_mask, x_freqs_cis, adaln_input + ) + else: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input) # unified control_context_unified = [] From 721011e52b33d953334c87ef8bf5545bd6987669 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 13:03:24 +0000 Subject: [PATCH 35/38] -self --- .../pipelines/z_image/pipeline_z_image_controlnet_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py index d18172083cef..156b09a571c3 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -191,7 +191,7 @@ def __init__( controlnet: ZImageControlNetModel, ): super().__init__() - if self.transformer.in_channels == self.controlnet.config.control_in_dim: + if transformer.in_channels == controlnet.config.control_in_dim: raise ValueError( "ZImageControlNetInpaintPipeline is not compatible with `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union`, use `alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0`." 
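+            # The check above rejects the 1.0 Union checkpoint, whose control embedder only takes the control
+            # latents (its control_in_dim equals the transformer's latent in_channels). The inpaint pipeline also
+            # packs the mask and the masked-image latents into the control input, which needs the 2.0 checkpoint.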
) From efadd91f8b4989ed2af71b7cf46bd2be4fff44e4 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 13:27:47 +0000 Subject: [PATCH 36/38] * (mask_condition < 0.5) --- .../pipelines/z_image/pipeline_z_image_controlnet_inpaint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py index 156b09a571c3..eab17ae37237 100644 --- a/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py +++ b/src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py @@ -566,6 +566,7 @@ def __call__( dtype=self.vae.dtype, ) height, width = init_image.shape[-2:] + init_image = init_image * (mask_condition < 0.5) init_image = retrieve_latents(self.vae.encode(init_image), generator=generator, sample_mode="argmax") init_image = (init_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor init_image = init_image.unsqueeze(2) From f4b7fcc67daefd2ac9762239a9fe0900aab95c89 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 14:09:01 +0000 Subject: [PATCH 37/38] create_z_image_controlnet_config unknown type --- src/diffusers/loaders/single_file_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 20ab5d827255..ed39c0d4837f 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -3923,3 +3923,5 @@ def create_z_image_controlnet_config(checkpoint, **kwargs): return v1_config elif control_x_embedder_weight_shape == 132: return v2_config + else: + raise ValueError("Unknown Z-Image Turbo ControlNet type.") From dd9775caf0906c6727fcbe2797cf6dc60cc38f45 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 13 Dec 2025 14:32:06 +0000 Subject: [PATCH 38/38] pop control_noise_refiner from 2.0 state_dict --- src/diffusers/loaders/single_file_model.py | 9 +++++++-- src/diffusers/loaders/single_file_utils.py | 13 +++++++++++++ .../models/controlnets/controlnet_z_image.py | 16 +--------------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 63b31b34529f..43abae4208ce 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -49,6 +49,7 @@ convert_stable_cascade_unet_single_file_to_diffusers, convert_wan_transformer_to_diffusers, convert_wan_vae_to_diffusers, + convert_z_image_controlnet_checkpoint_to_diffusers, convert_z_image_transformer_checkpoint_to_diffusers, create_controlnet_diffusers_config_from_ldm, create_unet_diffusers_config_from_ldm, @@ -174,14 +175,18 @@ "default_subfolder": "transformer", }, "ZImageControlNetModel": { - "checkpoint_mapping_fn": lambda x: x, + "checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers, "config_create_fn": create_z_image_controlnet_config, }, } def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict): - return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys())) + model_state_dict_keys = set(model_state_dict.keys()) + checkpoint_state_dict_keys = set(checkpoint_state_dict.keys()) + is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys) + is_match = model_state_dict_keys == checkpoint_state_dict_keys + return not (is_subset and is_match) def _get_single_file_loadable_mapping_class(cls): diff --git 
a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ed39c0d4837f..ed81ae752560 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -3925,3 +3925,16 @@ def create_z_image_controlnet_config(checkpoint, **kwargs): return v2_config else: raise ValueError("Unknown Z-Image Turbo ControlNet type.") + + +def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, **kwargs): + control_x_embedder_weight_shape = checkpoint["control_all_x_embedder.2-1.weight"].shape[1] + if control_x_embedder_weight_shape == 64: + return checkpoint + elif control_x_embedder_weight_shape == 132: + converted_state_dict = { + key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.") + } + return converted_state_dict + else: + raise ValueError("Unknown Z-Image Turbo ControlNet type.") diff --git a/src/diffusers/models/controlnets/controlnet_z_image.py b/src/diffusers/models/controlnets/controlnet_z_image.py index bc77d2a60bf6..1e628e560e8e 100644 --- a/src/diffusers/models/controlnets/controlnet_z_image.py +++ b/src/diffusers/models/controlnets/controlnet_z_image.py @@ -432,21 +432,7 @@ def __init__( self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) if self.add_control_noise_refiner: - self.control_noise_refiner = nn.ModuleList( - [ - ZImageControlTransformerBlock( - 1000 + layer_id, - dim, - n_heads, - n_kv_heads, - norm_eps, - qk_norm, - modulation=True, - block_id=layer_id, - ) - for layer_id in range(n_refiner_layers) - ] - ) + self.control_noise_refiner = None else: self.control_noise_refiner = nn.ModuleList( [