+
+## Popular libraries using 🧨 Diffusers
+
+- https://github.com/microsoft/TaskMatrix
+- https://github.com/invoke-ai/InvokeAI
+- https://github.com/InstantID/InstantID
+- https://github.com/apple/ml-stable-diffusion
+- https://github.com/Sanster/lama-cleaner
+- https://github.com/IDEA-Research/Grounded-Segment-Anything
+- https://github.com/ashawkey/stable-dreamfusion
+- https://github.com/deep-floyd/IF
+- https://github.com/bentoml/BentoML
+- https://github.com/bmaltais/kohya_ss
+- +14,000 other amazing GitHub repositories 💪
+
+Thank you for using us ❤️.
+
+## Credits
+
+This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished as it is today:
+
+- @CompVis' latent diffusion models library, available [here](https://github.com/CompVis/latent-diffusion)
+- @hojonathanho original DDPM implementation, available [here](https://github.com/hojonathanho/diffusion) as well as the extremely useful translation into PyTorch by @pesser, available [here](https://github.com/pesser/pytorch_diffusion)
+- @ermongroup's DDIM implementation, available [here](https://github.com/ermongroup/ddim)
+- @yang-song's Score-VE and Score-VP implementations, available [here](https://github.com/yang-song/score_sde_pytorch)
+
+We also want to thank @heejkoo for the very helpful overview of papers, code and resources on diffusion models, available [here](https://github.com/heejkoo/Awesome-Diffusion-Models) as well as @crowsonkb and @rromb for useful discussions and insights.
+
+## Citation
+
+```bibtex
+@misc{von-platen-etal-2022-diffusers,
+ author = {Patrick von Platen and Suraj Patil and Anton Lozhkov and Pedro Cuenca and Nathan Lambert and Kashif Rasul and Mishig Davaadorj and Dhruv Nair and Sayak Paul and William Berman and Yiyi Xu and Steven Liu and Thomas Wolf},
+ title = {Diffusers: State-of-the-art diffusion models},
+ year = {2022},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/huggingface/diffusers}}
+}
+```
diff --git a/diffusers/_typos.toml b/diffusers/_typos.toml
new file mode 100755
index 0000000..551099f
--- /dev/null
+++ b/diffusers/_typos.toml
@@ -0,0 +1,13 @@
+# Files for typos
+# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
+
+[default.extend-identifiers]
+
+[default.extend-words]
+NIN="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
+nd="np" # nd may be np (numpy)
+parms="parms" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py
+
+
+[files]
+extend-exclude = ["_typos.toml"]
diff --git a/diffusers/pyproject.toml b/diffusers/pyproject.toml
new file mode 100755
index 0000000..299865a
--- /dev/null
+++ b/diffusers/pyproject.toml
@@ -0,0 +1,29 @@
+[tool.ruff]
+line-length = 119
+
+[tool.ruff.lint]
+# Never enforce `E501` (line length violations).
+ignore = ["C901", "E501", "E741", "F402", "F823"]
+select = ["C", "E", "F", "I", "W"]
+
+# Ignore import violations in all `__init__.py` files.
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+"src/diffusers/utils/dummy_*.py" = ["F401"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = ["diffusers"]
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
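+
+# Usage note (assumption, not part of the upstream config): with ruff installed,
+# `ruff check .` and `ruff format .` read this [tool.ruff] table from pyproject.toml
+# automatically; no extra flags are needed.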
diff --git a/diffusers/scripts/__init__.py b/diffusers/scripts/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/diffusers/scripts/change_naming_configs_and_checkpoints.py b/diffusers/scripts/change_naming_configs_and_checkpoints.py
new file mode 100755
index 0000000..adc1605
--- /dev/null
+++ b/diffusers/scripts/change_naming_configs_and_checkpoints.py
@@ -0,0 +1,113 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Conversion script for the LDM checkpoints."""
+
+import argparse
+import json
+import os
+
+import torch
+from transformers.file_utils import has_file
+
+from diffusers import UNet2DConditionModel, UNet2DModel
+
+
+do_only_config = False
+do_only_weights = True
+do_only_renaming = False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--repo_path",
+ default=None,
+ type=str,
+ required=True,
+ help="The config json file corresponding to the architecture.",
+ )
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+ args = parser.parse_args()
+
+ config_parameters_to_change = {
+ "image_size": "sample_size",
+ "num_res_blocks": "layers_per_block",
+ "block_channels": "block_out_channels",
+ "down_blocks": "down_block_types",
+ "up_blocks": "up_block_types",
+ "downscale_freq_shift": "freq_shift",
+ "resnet_num_groups": "norm_num_groups",
+ "resnet_act_fn": "act_fn",
+ "resnet_eps": "norm_eps",
+ "num_head_channels": "attention_head_dim",
+ }
+
+ key_parameters_to_change = {
+ "time_steps": "time_proj",
+ "mid": "mid_block",
+ "downsample_blocks": "down_blocks",
+ "upsample_blocks": "up_blocks",
+ }
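+    # Illustrative example (hypothetical key, not from a real checkpoint): under the
+    # mapping above, a parameter named "downsample_blocks.0.resnets.0.conv1.weight"
+    # is re-rooted to "down_blocks.0.resnets.0.conv1.weight" by the renaming loop below.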
+
+ subfolder = "" if has_file(args.repo_path, "config.json") else "unet"
+
+ with open(os.path.join(args.repo_path, subfolder, "config.json"), "r", encoding="utf-8") as reader:
+ text = reader.read()
+ config = json.loads(text)
+
+ if do_only_config:
+ for key in config_parameters_to_change.keys():
+ config.pop(key, None)
+
+ if has_file(args.repo_path, "config.json"):
+ model = UNet2DModel(**config)
+ else:
+ class_name = UNet2DConditionModel if "ldm-text2im-large-256" in args.repo_path else UNet2DModel
+ model = class_name(**config)
+
+ if do_only_config:
+ model.save_config(os.path.join(args.repo_path, subfolder))
+
+ config = dict(model.config)
+
+ if do_only_renaming:
+ for key, value in config_parameters_to_change.items():
+ if key in config:
+ config[value] = config[key]
+ del config[key]
+
+ config["down_block_types"] = [k.replace("UNetRes", "") for k in config["down_block_types"]]
+ config["up_block_types"] = [k.replace("UNetRes", "") for k in config["up_block_types"]]
+
+ if do_only_weights:
+ state_dict = torch.load(os.path.join(args.repo_path, subfolder, "diffusion_pytorch_model.bin"))
+
+ new_state_dict = {}
+ for param_key, param_value in state_dict.items():
+ if param_key.endswith(".op.bias") or param_key.endswith(".op.weight"):
+ continue
+ has_changed = False
+ for key, new_key in key_parameters_to_change.items():
+ if not has_changed and param_key.split(".")[0] == key:
+ new_state_dict[".".join([new_key] + param_key.split(".")[1:])] = param_value
+ has_changed = True
+ if not has_changed:
+ new_state_dict[param_key] = param_value
+
+ model.load_state_dict(new_state_dict)
+ model.save_pretrained(os.path.join(args.repo_path, subfolder))
diff --git a/diffusers/scripts/conversion_ldm_uncond.py b/diffusers/scripts/conversion_ldm_uncond.py
new file mode 100755
index 0000000..8c22cc1
--- /dev/null
+++ b/diffusers/scripts/conversion_ldm_uncond.py
@@ -0,0 +1,56 @@
+import argparse
+
+import torch
+import yaml
+
+from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
+
+
+def convert_ldm_original(checkpoint_path, config_path, output_path):
+    with open(config_path, "r") as f:
+        config = yaml.safe_load(f)
+ state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+ keys = list(state_dict.keys())
+
+ # extract state_dict for VQVAE
+ first_stage_dict = {}
+ first_stage_key = "first_stage_model."
+ for key in keys:
+ if key.startswith(first_stage_key):
+ first_stage_dict[key.replace(first_stage_key, "")] = state_dict[key]
+
+ # extract state_dict for UNetLDM
+ unet_state_dict = {}
+ unet_key = "model.diffusion_model."
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
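+    # e.g. a (hypothetical) key "model.diffusion_model.input_blocks.0.0.weight" is stored
+    # as "input_blocks.0.0.weight"; stripping the prefixes lets the standalone VQModel and
+    # UNetLDMModel modules load these sub-dicts directly.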
+
+ vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
+ unet_init_args = config["model"]["params"]["unet_config"]["params"]
+
+ vqvae = VQModel(**vqvae_init_args).eval()
+ vqvae.load_state_dict(first_stage_dict)
+
+ unet = UNetLDMModel(**unet_init_args).eval()
+ unet.load_state_dict(unet_state_dict)
+
+ noise_scheduler = DDIMScheduler(
+        num_train_timesteps=config["model"]["params"]["timesteps"],
+ beta_schedule="scaled_linear",
+ beta_start=config["model"]["params"]["linear_start"],
+ beta_end=config["model"]["params"]["linear_end"],
+ clip_sample=False,
+ )
+
+ pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
+ pipeline.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", type=str, required=True)
+ parser.add_argument("--config_path", type=str, required=True)
+ parser.add_argument("--output_path", type=str, required=True)
+ args = parser.parse_args()
+
+ convert_ldm_original(args.checkpoint_path, args.config_path, args.output_path)
diff --git a/diffusers/scripts/convert_amused.py b/diffusers/scripts/convert_amused.py
new file mode 100755
index 0000000..21be29d
--- /dev/null
+++ b/diffusers/scripts/convert_amused.py
@@ -0,0 +1,523 @@
+import inspect
+import os
+from argparse import ArgumentParser
+
+import numpy as np
+import torch
+from muse import MaskGiTUViT, VQGANModel
+from muse import PipelineMuse as OldPipelineMuse
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import VQModel
+from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.unets.uvit_2d import UVit2DModel
+from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
+from diffusers.schedulers import AmusedScheduler
+
+
+torch.backends.cuda.enable_flash_sdp(False)
+torch.backends.cuda.enable_mem_efficient_sdp(False)
+torch.backends.cuda.enable_math_sdp(True)
+
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+torch.use_deterministic_algorithms(True)
+
+# Enable CUDNN deterministic mode
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+torch.backends.cuda.matmul.allow_tf32 = False
+
+device = "cuda"
+
+
+def main():
+ args = ArgumentParser()
+ args.add_argument("--model_256", action="store_true")
+ args.add_argument("--write_to", type=str, required=False, default=None)
+ args.add_argument("--transformer_path", type=str, required=False, default=None)
+ args = args.parse_args()
+
+ transformer_path = args.transformer_path
+ subfolder = "transformer"
+
+ if transformer_path is None:
+ if args.model_256:
+ transformer_path = "openMUSE/muse-256"
+ else:
+ transformer_path = (
+ "../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/"
+ )
+ subfolder = None
+
+ old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder)
+
+ old_transformer.to(device)
+
+ old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae")
+ old_vae.to(device)
+
+ vqvae = make_vqvae(old_vae)
+
+    tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="tokenizer")
+
+ text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder")
+ text_encoder.to(device)
+
+ transformer = make_transformer(old_transformer, args.model_256)
+
+ scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id)
+
+ new_pipe = AmusedPipeline(
+ vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler
+ )
+
+ old_pipe = OldPipelineMuse(
+ vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer
+ )
+ old_pipe.to(device)
+
+ if args.model_256:
+ transformer_seq_len = 256
+ orig_size = (256, 256)
+ else:
+ transformer_seq_len = 1024
+ orig_size = (512, 512)
+
+ old_out = old_pipe(
+ "dog",
+ generator=torch.Generator(device).manual_seed(0),
+ transformer_seq_len=transformer_seq_len,
+ orig_size=orig_size,
+ timesteps=12,
+ )[0]
+
+ new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0]
+
+ old_out = np.array(old_out)
+ new_out = np.array(new_out)
+
+ diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64))
+
+    # assert diff.sum() == 0
+ print("skipping pipeline full equivalence check")
+
+ print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}")
+
+ if args.model_256:
+ assert diff.max() <= 3
+ assert diff.sum() / diff.size < 0.7
+ else:
+ assert diff.max() <= 1
+ assert diff.sum() / diff.size < 0.4
+
+ if args.write_to is not None:
+ new_pipe.save_pretrained(args.write_to)
+
+
+def make_transformer(old_transformer, model_256):
+ args = dict(old_transformer.config)
+ force_down_up_sample = args["force_down_up_sample"]
+
+ signature = inspect.signature(UVit2DModel.__init__)
+
+ args_ = {
+ "downsample": force_down_up_sample,
+ "upsample": force_down_up_sample,
+ "block_out_channels": args["block_out_channels"][0],
+ "sample_size": 16 if model_256 else 32,
+ }
+
+ for s in list(signature.parameters.keys()):
+ if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]:
+ continue
+
+ args_[s] = args[s]
+
+ new_transformer = UVit2DModel(**args_)
+ new_transformer.to(device)
+
+ new_transformer.set_attn_processor(AttnProcessor())
+
+ state_dict = old_transformer.state_dict()
+
+ state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight")
+ state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight")
+
+ for i in range(22):
+ state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attn_layer_norm.weight"
+ )
+ state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.query.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.key.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.value.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.attention.out.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattn_layer_norm.weight"
+ )
+ state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.query.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.key.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.value.weight"
+ )
+ state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.crossattention.out.weight"
+ )
+
+ state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight"
+ )
+ state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop(
+ f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight"
+ )
+
+ wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight")
+ wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight")
+ proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0)
+ state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight
+
+ state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight")
+
+ if force_down_up_sample:
+ state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight")
+ state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight")
+
+ state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight")
+ state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight")
+
+ state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight")
+
+ for i in range(3):
+ state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.norm.norm.weight"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.0.weight"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.2.beta"
+ )
+ state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.channelwise.4.weight"
+ )
+ state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
+ f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.query.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.key.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.value.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.attention.out.weight"
+ )
+
+ state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight"
+ )
+ state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight"
+ )
+
+ state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.norm.norm.weight"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.0.weight"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.2.beta"
+ )
+ state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.channelwise.4.weight"
+ )
+ state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop(
+ f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight"
+ )
+
+ state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.query.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.key.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.value.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.attention.out.weight"
+ )
+
+ state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight"
+ )
+ state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight"
+ )
+
+ for key in list(state_dict.keys()):
+ if key.startswith("up_blocks.0"):
+ key_ = "up_block." + ".".join(key.split(".")[2:])
+ state_dict[key_] = state_dict.pop(key)
+
+ if key.startswith("down_blocks.0"):
+ key_ = "down_block." + ".".join(key.split(".")[2:])
+ state_dict[key_] = state_dict.pop(key)
+
+ new_transformer.load_state_dict(state_dict)
+
+ input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device)
+ encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device)
+ cond_embeds = torch.randn((1, 768), device=old_transformer.device)
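+    # The tensor below presumably packs aMUSEd's micro-conditioning as
+    # [orig_width, orig_height, crop_top, crop_left, aesthetic_score] = [512, 512, 0, 0, 6]
+    # (an assumption about the layout, not confirmed by this script).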
+ micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device)
+
+ old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds)
+ old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2)
+
+ new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
+
+ # NOTE: these differences are solely due to using the geglu block that has a single linear layer of
+ # double output dimension instead of two different linear layers
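+    # A minimal sketch of the equivalence (hypothetical shapes, mirroring the concat above):
+    #   x = torch.randn(2, 16)
+    #   w0, w1 = torch.randn(8, 16), torch.randn(8, 16)
+    #   fused = torch.concat([w1, w0], dim=0)
+    #   assert torch.allclose(x @ fused.T, torch.cat([x @ w1.T, x @ w0.T], dim=-1))
+    # so any residual mismatch is down to floating-point reduction order, not the weights.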
+ max_diff = (old_out - new_out).abs().max()
+ total_diff = (old_out - new_out).abs().sum()
+ print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}")
+ assert max_diff < 0.01
+ assert total_diff < 1500
+
+ return new_transformer
+
+
+def make_vqvae(old_vae):
+ new_vae = VQModel(
+ act_fn="silu",
+ block_out_channels=[128, 256, 256, 512, 768],
+ down_block_types=[
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ in_channels=3,
+ latent_channels=64,
+ layers_per_block=2,
+ norm_num_groups=32,
+ num_vq_embeddings=8192,
+ out_channels=3,
+ sample_size=32,
+ up_block_types=[
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ mid_block_add_attention=False,
+ lookup_from_codebook=True,
+ )
+ new_vae.to(device)
+
+ # fmt: off
+
+ new_state_dict = {}
+
+ old_state_dict = old_vae.state_dict()
+
+ new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
+ new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
+
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
+ convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
+
+ new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
+ new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
+ new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
+ new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
+ new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
+ new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
+ new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
+ new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
+ new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
+ new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
+ new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
+ new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
+ new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")
+ new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias")
+ new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight")
+ new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias")
+ new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight")
+ new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias")
+ new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight")
+ new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias")
+ new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight")
+ new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias")
+ new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight")
+ new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight")
+ new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias")
+ new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight")
+ new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias")
+ new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight")
+ new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias")
+ new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight")
+ new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias")
+ new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight")
+ new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias")
+ new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight")
+ new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias")
+ new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight")
+ new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias")
+ new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight")
+ new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias")
+ new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight")
+ new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias")
+ new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight")
+ new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias")
+
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1")
+ convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0")
+
+ new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight")
+ new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias")
+ new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight")
+ new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias")
+
+ # fmt: on
+
+ assert len(old_state_dict.keys()) == 0
+
+ new_vae.load_state_dict(new_state_dict)
+
+ input = torch.randn((1, 3, 512, 512), device=device)
+ input = input.clamp(-1, 1)
+
+ old_encoder_output = old_vae.quant_conv(old_vae.encoder(input))
+ new_encoder_output = new_vae.quant_conv(new_vae.encoder(input))
+ assert (old_encoder_output == new_encoder_output).all()
+
+ old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output))
+ new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output))
+
+ # assert (old_decoder_output == new_decoder_output).all()
+ print("kipping vae decoder equivalence check")
+ print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}")
+
+ old_output = old_vae(input)[0]
+ new_output = new_vae(input)[0]
+
+ # assert (old_output == new_output).all()
+ print("skipping full vae equivalence check")
+ print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
+
+ return new_vae
+
+
+def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to):
+ # fmt: off
+
+ new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias")
+ new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias")
+ new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias")
+ new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias")
+
+ if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight")
+ new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias")
+
+ new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias")
+ new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias")
+ new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias")
+ new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight")
+ new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias")
+
+ if f"{prefix_from}.downsample.conv.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight")
+ new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias")
+
+ if f"{prefix_from}.upsample.conv.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight")
+ new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias")
+
+ if f"{prefix_from}.block.2.norm1.weight" in old_state_dict:
+ new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias")
+ new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias")
+ new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias")
+ new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight")
+ new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias")
+
+ # fmt: on
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py b/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py
new file mode 100755
index 0000000..21567ff
--- /dev/null
+++ b/diffusers/scripts/convert_animatediff_motion_lora_to_diffusers.py
@@ -0,0 +1,69 @@
+import argparse
+import os
+
+import torch
+from huggingface_hub import create_repo, upload_folder
+from safetensors.torch import load_file, save_file
+
+
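+# Illustrative rename (hypothetical key): the replaces below turn
+#   "down_blocks.0.motion_modules.0.temporal_transformer.attention_blocks.0.to_q.weight"
+# into
+#   "down_blocks.0.motion_modules.0.attn1.to_q.weight"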
+def convert_motion_module(original_state_dict):
+ converted_state_dict = {}
+ for k, v in original_state_dict.items():
+ if "pos_encoder" in k:
+ continue
+
+ else:
+ converted_state_dict[
+ k.replace(".norms.0", ".norm1")
+ .replace(".norms.1", ".norm2")
+ .replace(".ff_norm", ".norm3")
+ .replace(".attention_blocks.0", ".attn1")
+ .replace(".attention_blocks.1", ".attn2")
+ .replace(".temporal_transformer", "")
+ ] = v
+
+ return converted_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
+ parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ default=False,
+ help="Whether to push the converted model to the HF or not",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ if args.ckpt_path.endswith(".safetensors"):
+ state_dict = load_file(args.ckpt_path)
+ else:
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
+
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+
+ conv_state_dict = convert_motion_module(state_dict)
+
+ # convert to new format
+ output_dict = {}
+ for module_name, params in conv_state_dict.items():
+ if type(params) is not torch.Tensor:
+ continue
+ output_dict.update({f"unet.{module_name}": params})
+
+ os.makedirs(args.output_path, exist_ok=True)
+
+ filepath = os.path.join(args.output_path, "diffusion_pytorch_model.safetensors")
+ save_file(output_dict, filepath)
+
+ if args.push_to_hub:
+ repo_id = create_repo(args.output_path, exist_ok=True).repo_id
+ upload_folder(repo_id=repo_id, folder_path=args.output_path, repo_type="model")
diff --git a/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py b/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py
new file mode 100755
index 0000000..e188a6a
--- /dev/null
+++ b/diffusers/scripts/convert_animatediff_motion_module_to_diffusers.py
@@ -0,0 +1,62 @@
+import argparse
+
+import torch
+from safetensors.torch import load_file
+
+from diffusers import MotionAdapter
+
+
+def convert_motion_module(original_state_dict):
+ converted_state_dict = {}
+ for k, v in original_state_dict.items():
+ if "pos_encoder" in k:
+ continue
+
+ else:
+ converted_state_dict[
+ k.replace(".norms.0", ".norm1")
+ .replace(".norms.1", ".norm2")
+ .replace(".ff_norm", ".norm3")
+ .replace(".attention_blocks.0", ".attn1")
+ .replace(".attention_blocks.1", ".attn2")
+ .replace(".temporal_transformer", "")
+ ] = v
+
+ return converted_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=str, required=True)
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--use_motion_mid_block", action="store_true")
+ parser.add_argument("--motion_max_seq_length", type=int, default=32)
+ parser.add_argument("--block_out_channels", nargs="+", default=[320, 640, 1280, 1280], type=int)
+ parser.add_argument("--save_fp16", action="store_true")
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ if args.ckpt_path.endswith(".safetensors"):
+ state_dict = load_file(args.ckpt_path)
+ else:
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
+
+ if "state_dict" in state_dict.keys():
+ state_dict = state_dict["state_dict"]
+
+ conv_state_dict = convert_motion_module(state_dict)
+ adapter = MotionAdapter(
+ block_out_channels=args.block_out_channels,
+ use_motion_mid_block=args.use_motion_mid_block,
+ motion_max_seq_length=args.motion_max_seq_length,
+ )
+ # skip loading position embeddings
+ adapter.load_state_dict(conv_state_dict, strict=False)
+ adapter.save_pretrained(args.output_path)
+
+ if args.save_fp16:
+ adapter.to(dtype=torch.float16).save_pretrained(args.output_path, variant="fp16")
diff --git a/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py b/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py
new file mode 100755
index 0000000..f246dce
--- /dev/null
+++ b/diffusers/scripts/convert_animatediff_sparsectrl_to_diffusers.py
@@ -0,0 +1,83 @@
+import argparse
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from diffusers import SparseControlNetModel
+
+
+KEYS_RENAME_MAPPING = {
+ ".attention_blocks.0": ".attn1",
+ ".attention_blocks.1": ".attn2",
+ ".attn1.pos_encoder": ".pos_embed",
+ ".ff_norm": ".norm3",
+ ".norms.0": ".norm1",
+ ".norms.1": ".norm2",
+ ".temporal_transformer": "",
+}
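+# Note the ordering dependency (as read from the dict above): ".attention_blocks.0" is
+# rewritten to ".attn1" first, so the later ".attn1.pos_encoder" entry can match and
+# become ".pos_embed"; dicts preserve insertion order in Python 3.7+.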
+
+
+def convert(original_state_dict: Dict[str, nn.Module]) -> Dict[str, nn.Module]:
+ converted_state_dict = {}
+
+ for key in list(original_state_dict.keys()):
+ renamed_key = key
+        for old_name, new_name in KEYS_RENAME_MAPPING.items():
+            renamed_key = renamed_key.replace(old_name, new_name)
+ converted_state_dict[renamed_key] = original_state_dict.pop(key)
+
+ return converted_state_dict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt_path", type=str, required=True, help="Path to checkpoint")
+ parser.add_argument("--output_path", type=str, required=True, help="Path to output directory")
+ parser.add_argument(
+ "--max_motion_seq_length",
+ type=int,
+ default=32,
+ help="Max motion sequence length supported by the motion adapter",
+ )
+ parser.add_argument(
+ "--conditioning_channels", type=int, default=4, help="Number of channels in conditioning input to controlnet"
+ )
+ parser.add_argument(
+ "--use_simplified_condition_embedding",
+ action="store_true",
+ default=False,
+ help="Whether or not to use simplified condition embedding. When `conditioning_channels==4` i.e. latent inputs, set this to `True`. When `conditioning_channels==3` i.e. image inputs, set this to `False`",
+ )
+ parser.add_argument(
+ "--save_fp16",
+ action="store_true",
+ default=False,
+ help="Whether or not to save model in fp16 precision along with fp32",
+ )
+ parser.add_argument(
+ "--push_to_hub", action="store_true", default=False, help="Whether or not to push saved model to the HF hub"
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ state_dict = torch.load(args.ckpt_path, map_location="cpu")
+ if "state_dict" in state_dict.keys():
+ state_dict: dict = state_dict["state_dict"]
+
+ controlnet = SparseControlNetModel(
+ conditioning_channels=args.conditioning_channels,
+ motion_max_seq_length=args.max_motion_seq_length,
+ use_simplified_condition_embedding=args.use_simplified_condition_embedding,
+ )
+
+ state_dict = convert(state_dict)
+ controlnet.load_state_dict(state_dict, strict=True)
+
+ controlnet.save_pretrained(args.output_path, push_to_hub=args.push_to_hub)
+ if args.save_fp16:
+ controlnet = controlnet.to(dtype=torch.float16)
+ controlnet.save_pretrained(args.output_path, variant="fp16", push_to_hub=args.push_to_hub)
diff --git a/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py
new file mode 100755
index 0000000..ffb735e
--- /dev/null
+++ b/diffusers/scripts/convert_asymmetric_vqgan_to_diffusers.py
@@ -0,0 +1,184 @@
+import argparse
+import time
+from pathlib import Path
+from typing import Any, Dict, Literal
+
+import torch
+
+from diffusers import AsymmetricAutoencoderKL
+
+
+ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ "down_block_out_channels": [128, 256, 512, 512],
+ "layers_per_down_block": 2,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ "up_block_out_channels": [192, 384, 768, 768],
+ "layers_per_up_block": 3,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": 32,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+}
+
+ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ],
+ "down_block_out_channels": [128, 256, 512, 512],
+ "layers_per_down_block": 2,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ],
+ "up_block_out_channels": [256, 512, 1024, 1024],
+ "layers_per_up_block": 5,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": 32,
+ "sample_size": 256,
+ "scaling_factor": 0.18215,
+}
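+
+# The two presets differ only in the decoder: relative to the shared encoder widths,
+# the "1.5" variant scales up_block_out_channels by 1.5x with 3 layers per up block,
+# and the "2" variant by 2x with 5 layers per up block.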
+
+
+def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
+ converted_state_dict = {}
+ for k, v in original_state_dict.items():
+ if k.startswith("encoder."):
+ converted_state_dict[
+ k.replace("encoder.down.", "encoder.down_blocks.")
+ .replace("encoder.mid.", "encoder.mid_block.")
+ .replace("encoder.norm_out.", "encoder.conv_norm_out.")
+ .replace(".downsample.", ".downsamplers.0.")
+ .replace(".nin_shortcut.", ".conv_shortcut.")
+ .replace(".block.", ".resnets.")
+ .replace(".block_1.", ".resnets.0.")
+ .replace(".block_2.", ".resnets.1.")
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
+ ] = v
+ elif k.startswith("decoder.") and "up_layers" not in k:
+ converted_state_dict[
+ k.replace("decoder.encoder.", "decoder.condition_encoder.")
+ .replace(".norm_out.", ".conv_norm_out.")
+ .replace(".up.0.", ".up_blocks.3.")
+ .replace(".up.1.", ".up_blocks.2.")
+ .replace(".up.2.", ".up_blocks.1.")
+ .replace(".up.3.", ".up_blocks.0.")
+ .replace(".block.", ".resnets.")
+ .replace("mid", "mid_block")
+ .replace(".0.upsample.", ".0.upsamplers.0.")
+ .replace(".1.upsample.", ".1.upsamplers.0.")
+ .replace(".2.upsample.", ".2.upsamplers.0.")
+ .replace(".nin_shortcut.", ".conv_shortcut.")
+ .replace(".block_1.", ".resnets.0.")
+ .replace(".block_2.", ".resnets.1.")
+ .replace(".attn_1.k.", ".attentions.0.to_k.")
+ .replace(".attn_1.q.", ".attentions.0.to_q.")
+ .replace(".attn_1.v.", ".attentions.0.to_v.")
+ .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
+ .replace(".attn_1.norm.", ".attentions.0.group_norm.")
+ ] = v
+ elif k.startswith("quant_conv."):
+ converted_state_dict[k] = v
+ elif k.startswith("post_quant_conv."):
+ converted_state_dict[k] = v
+ else:
+ print(f" skipping key `{k}`")
+    # fix weights shape: attention q/k/v/out projections stored as 1x1 convs ([out, in, 1, 1]) are squeezed to linear weights ([out, in])
+ for k, v in converted_state_dict.items():
+ if (
+ (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
+ and k.endswith("weight")
+ and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
+ ):
+ converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]
+
+ return converted_state_dict
+
+
+def get_asymmetric_autoencoder_kl_from_original_checkpoint(
+ scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
+) -> AsymmetricAutoencoderKL:
+ print("Loading original state_dict")
+ original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
+ original_state_dict = original_state_dict["state_dict"]
+ print("Converting state_dict")
+ converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
+ kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
+ print("Initializing AsymmetricAutoencoderKL model")
+ asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
+ print("Loading weight from converted state_dict")
+ asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
+ asymmetric_autoencoder_kl.eval()
+ print("AsymmetricAutoencoderKL successfully initialized")
+ return asymmetric_autoencoder_kl
+
+
+if __name__ == "__main__":
+ start = time.time()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--scale",
+ default=None,
+ type=str,
+ required=True,
+ help="Asymmetric VQGAN scale: `1.5` or `2`",
+ )
+ parser.add_argument(
+ "--original_checkpoint_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the original Asymmetric VQGAN checkpoint",
+ )
+ parser.add_argument(
+ "--output_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to save pretrained AsymmetricAutoencoderKL model",
+ )
+ parser.add_argument(
+ "--map_location",
+ default="cpu",
+ type=str,
+ required=False,
+ help="The device passed to `map_location` when loading the checkpoint",
+ )
+ args = parser.parse_args()
+
+    assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` or `2`"
+ assert Path(args.original_checkpoint_path).is_file()
+
+ asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
+ scale=args.scale,
+ original_checkpoint_path=args.original_checkpoint_path,
+ map_location=torch.device(args.map_location),
+ )
+ print("Saving pretrained AsymmetricAutoencoderKL")
+ asymmetric_autoencoder_kl.save_pretrained(args.output_path)
+ print(f"Done in {time.time() - start:.2f} seconds")
diff --git a/diffusers/scripts/convert_aura_flow_to_diffusers.py b/diffusers/scripts/convert_aura_flow_to_diffusers.py
new file mode 100755
index 0000000..74c34f4
--- /dev/null
+++ b/diffusers/scripts/convert_aura_flow_to_diffusers.py
@@ -0,0 +1,131 @@
+import argparse
+
+import torch
+from huggingface_hub import hf_hub_download
+
+from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
+
+
+def load_original_state_dict(args):
+ model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
+ state_dict = torch.load(model_pt, map_location="cpu")
+ return state_dict
+
+
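+# Count distinct block indices; e.g. a (hypothetical) key "model.double_layers.3.attn.w1q.weight"
+# contributes index 3 via k.split(".")[2].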
+def calculate_layers(state_dict_keys, key_prefix):
+ dit_layers = set()
+ for k in state_dict_keys:
+ if key_prefix in k:
+ dit_layers.add(int(k.split(".")[2]))
+ print(f"{key_prefix}: {len(dit_layers)}")
+ return len(dit_layers)
+
+
+# similar to SD3 but only for the last norm layer
+def swap_scale_shift(weight, dim):
+ shift, scale = weight.chunk(2, dim=0)
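+    # reorder a fused modulation weight from [shift; scale] to [scale; shift] along dim 0,
+    # presumably the chunk order expected by diffusers' norm_out (the `dim` arg is unused).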
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_transformer(state_dict):
+ converted_state_dict = {}
+ state_dict_keys = list(state_dict.keys())
+
+ converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
+ converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
+ converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
+ converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
+
+ converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
+ converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
+ converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
+ converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
+
+ converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
+
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+ # MMDiT blocks 🎸.
+ for i in range(mmdit_layers):
+ # feed-forward
+ path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+ weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for orig_k, diffuser_k in path_mapping.items():
+ for k, v in weight_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
+ f"model.double_layers.{i}.{orig_k}.{k}.weight"
+ )
+
+ # norms
+ path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+ for orig_k, diffuser_k in path_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
+ f"model.double_layers.{i}.{orig_k}.1.weight"
+ )
+
+ # attns
+ x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+ context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+ for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+ for k, v in attn_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
+ f"model.double_layers.{i}.attn.{k}.weight"
+ )
+
+ # Single-DiT blocks.
+ for i in range(single_dit_layers):
+ # feed-forward
+ mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for k, v in mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
+ f"model.single_layers.{i}.mlp.{k}.weight"
+ )
+
+ # norms
+ converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
+ f"model.single_layers.{i}.modCX.1.weight"
+ )
+
+ # attns
+ x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+ for k, v in x_attn_mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
+ f"model.single_layers.{i}.attn.{k}.weight"
+ )
+
+ # Final blocks.
+ converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
+
+ return converted_state_dict
+
+
+@torch.no_grad()
+def populate_state_dict(args):
+ original_state_dict = load_original_state_dict(args)
+ state_dict_keys = list(original_state_dict.keys())
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+ converted_state_dict = convert_transformer(original_state_dict)
+ model_diffusers = AuraFlowTransformer2DModel(
+ num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
+ )
+ model_diffusers.load_state_dict(converted_state_dict, strict=True)
+
+ return model_diffusers
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
+ parser.add_argument("--dump_path", default="aura-flow", type=str)
+ parser.add_argument("--hub_id", default=None, type=str)
+ args = parser.parse_args()
+
+ model_diffusers = populate_state_dict(args)
+ model_diffusers.save_pretrained(args.dump_path)
+ if args.hub_id is not None:
+ model_diffusers.push_to_hub(args.hub_id)
diff --git a/diffusers/scripts/convert_blipdiffusion_to_diffusers.py b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py
new file mode 100755
index 0000000..03cf67e
--- /dev/null
+++ b/diffusers/scripts/convert_blipdiffusion_to_diffusers.py
@@ -0,0 +1,343 @@
+"""
+This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
+"""
+
+import argparse
+import os
+import tempfile
+
+import torch
+from lavis.models import load_model_and_preprocess
+from transformers import CLIPTokenizer
+from transformers.models.blip_2.configuration_blip_2 import Blip2Config
+
+from diffusers import (
+ AutoencoderKL,
+ PNDMScheduler,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines import BlipDiffusionPipeline
+from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
+from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
+from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
+
+
+BLIP2_CONFIG = {
+ "vision_config": {
+ "hidden_size": 1024,
+ "num_hidden_layers": 23,
+ "num_attention_heads": 16,
+ "image_size": 224,
+ "patch_size": 14,
+ "intermediate_size": 4096,
+ "hidden_act": "quick_gelu",
+ },
+ "qformer_config": {
+ "cross_attention_frequency": 1,
+ "encoder_hidden_size": 1024,
+ "vocab_size": 30523,
+ },
+ "num_query_tokens": 16,
+}
+blip2config = Blip2Config(**BLIP2_CONFIG)
+
+
+def qformer_model_from_original_config():
+ qformer = Blip2QFormerModel(blip2config)
+ return qformer
+
+
+def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
+ embeddings = {}
+ embeddings.update(
+ {
+ f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
+ f"{original_embeddings_prefix}.word_embeddings.weight"
+ ]
+ }
+ )
+ embeddings.update(
+ {
+ f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
+ f"{original_embeddings_prefix}.position_embeddings.weight"
+ ]
+ }
+ )
+ embeddings.update(
+ {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
+ )
+ embeddings.update(
+ {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
+ )
+ return embeddings
+
+
+def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
+ proj_layer = {}
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
+ proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
+ return proj_layer
+
+
+def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
+ attention = {}
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.attention.query.weight": model[
+ f"{original_attention_prefix}.self.query.weight"
+ ]
+ }
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.attention.value.weight": model[
+ f"{original_attention_prefix}.self.value.weight"
+ ]
+ }
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
+ )
+ attention.update(
+ {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
+ f"{original_attention_prefix}.output.LayerNorm.weight"
+ ]
+ }
+ )
+ attention.update(
+ {
+ f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
+ f"{original_attention_prefix}.output.LayerNorm.bias"
+ ]
+ }
+ )
+ return attention
+
+
+def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
+ output_layers = {}
+ output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
+ output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
+ output_layers.update(
+ {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
+ )
+ output_layers.update(
+ {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
+ )
+ return output_layers
+
+
+def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
+ encoder = {}
+ for i in range(blip2config.qformer_config.num_hidden_layers):
+ encoder.update(
+ attention_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
+ )
+ )
+ encoder.update(
+ attention_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
+ )
+ )
+
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
+ f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
+ f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
+ ]
+ }
+ )
+ encoder.update(
+ {
+ f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
+ f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
+ ]
+ }
+ )
+
+ encoder.update(
+ output_layers_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
+ )
+ )
+ encoder.update(
+ output_layers_from_original_checkpoint(
+ model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
+ )
+ )
+ return encoder
+
+
+def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
+ visual_encoder_layer = {}
+
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
+ )
+ visual_encoder_layer.update(
+ {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
+ )
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
+ visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
+
+ return visual_encoder_layer
+
+
+def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
+ visual_encoder = {}
+
+ visual_encoder.update(
+ {
+ f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
+ .unsqueeze(0)
+ .unsqueeze(0)
+ }
+ )
+ visual_encoder.update(
+ {
+ f"{diffuser_prefix}.embeddings.position_embedding": model[
+ f"{original_prefix}.positional_embedding"
+ ].unsqueeze(0)
+ }
+ )
+ visual_encoder.update(
+ {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
+ )
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
+ visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
+
+ for i in range(blip2config.vision_config.num_hidden_layers):
+ visual_encoder.update(
+ visual_encoder_layer_from_original_checkpoint(
+ model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
+ )
+ )
+
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
+ visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
+
+ return visual_encoder
+
+
+def qformer_original_checkpoint_to_diffusers_checkpoint(model):
+ qformer_checkpoint = {}
+ qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
+ qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
+ qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
+ qformer_checkpoint.update(
+ encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
+ )
+ qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
+ return qformer_checkpoint
+
+
+def get_qformer(model):
+ print("loading qformer")
+
+ qformer = qformer_model_from_original_config()
+ qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
+
+ load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
+
+ print("done loading qformer")
+ return qformer
+
+
+def load_checkpoint_to_model(checkpoint, model):
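+    # Round-trip through a temporary file so the in-memory checkpoint dict can
+    # be freed (see the `del` below) before load_state_dict materializes the
+    # model weights; the temporary file is removed afterwards.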
+ with tempfile.NamedTemporaryFile(delete=False) as file:
+ torch.save(checkpoint, file.name)
+ del checkpoint
+ model.load_state_dict(torch.load(file.name), strict=False)
+
+ os.remove(file.name)
+
+
+def save_blip_diffusion_model(model, args):
+ qformer = get_qformer(model)
+ qformer.eval()
+
+ text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ vae.eval()
+ text_encoder.eval()
+ scheduler = PNDMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ set_alpha_to_one=False,
+ skip_prk_steps=True,
+ )
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+ image_processor = BlipImageProcessor()
+ blip_diffusion = BlipDiffusionPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ scheduler=scheduler,
+ qformer=qformer,
+ image_processor=image_processor,
+ )
+ blip_diffusion.save_pretrained(args.checkpoint_path)
+
+
+def main(args):
+ model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
+ save_blip_diffusion_model(model.state_dict(), args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_cogvideox_to_diffusers.py b/diffusers/scripts/convert_cogvideox_to_diffusers.py
new file mode 100755
index 0000000..4343eaf
--- /dev/null
+++ b/diffusers/scripts/convert_cogvideox_to_diffusers.py
@@ -0,0 +1,290 @@
+import argparse
+from typing import Any, Dict
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers import (
+ AutoencoderKLCogVideoX,
+ CogVideoXDDIMScheduler,
+ CogVideoXImageToVideoPipeline,
+ CogVideoXPipeline,
+ CogVideoXTransformer3DModel,
+)
+
+
+def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
+ to_q_key = key.replace("query_key_value", "to_q")
+ to_k_key = key.replace("query_key_value", "to_k")
+ to_v_key = key.replace("query_key_value", "to_v")
+ to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
+ state_dict[to_q_key] = to_q
+ state_dict[to_k_key] = to_k
+ state_dict[to_v_key] = to_v
+ state_dict.pop(key)
+
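+# Minimal sanity check (hypothetical key and shapes): a fused QKV weight of shape
+# (3 * d, in_dim) splits into three (d, in_dim) projections along dim 0.
+#
+#     sd = {"layers.0.attention.query_key_value.weight": torch.randn(192, 64)}
+#     reassign_query_key_value_inplace("layers.0.attention.query_key_value.weight", sd)
+#     assert sd["layers.0.attention.to_q.weight"].shape == (64, 64)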
+
+def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
+ layer_id, weight_or_bias = key.split(".")[-2:]
+
+ if "query" in key:
+ new_key = f"transformer_blocks.{layer_id}.attn1.norm_q.{weight_or_bias}"
+ elif "key" in key:
+ new_key = f"transformer_blocks.{layer_id}.attn1.norm_k.{weight_or_bias}"
+
+ state_dict[new_key] = state_dict.pop(key)
+
+
+def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
+ layer_id, _, weight_or_bias = key.split(".")[-3:]
+
+ weights_or_biases = state_dict[key].chunk(12, dim=0)
+ norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
+ norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])
+
+ norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
+ state_dict[norm1_key] = norm1_weights_or_biases
+
+ norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
+ state_dict[norm2_key] = norm2_weights_or_biases
+
+ state_dict.pop(key)
+
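+# The original adaLN projection stacks 12 equally sized chunks along dim 0;
+# chunks 0:3 and 6:9 are regrouped into norm1 and chunks 3:6 and 9:12 into
+# norm2 above, so a (12 * d, in) weight (hypothetical shape) becomes two
+# (6 * d, in) weights.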
+
+def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
+ state_dict.pop(key)
+
+
+def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
+ key_split = key.split(".")
+ layer_index = int(key_split[2])
+ replace_layer_index = 4 - 1 - layer_index
+
+ key_split[1] = "up_blocks"
+ key_split[2] = str(replace_layer_index)
+ new_key = ".".join(key_split)
+
+ state_dict[new_key] = state_dict.pop(key)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "transformer.final_layernorm": "norm_final",
+ "transformer": "transformer_blocks",
+ "attention": "attn1",
+ "mlp": "ff.net",
+ "dense_h_to_4h": "0.proj",
+ "dense_4h_to_h": "2",
+ ".layers": "",
+ "dense": "to_out.0",
+ "input_layernorm": "norm1.norm",
+ "post_attn1_layernorm": "norm2.norm",
+ "time_embed.0": "time_embedding.linear_1",
+ "time_embed.2": "time_embedding.linear_2",
+ "mixins.patch_embed": "patch_embed",
+ "mixins.final_layer.norm_final": "norm_out.norm",
+ "mixins.final_layer.linear": "proj_out",
+ "mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
+ "mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "query_key_value": reassign_query_key_value_inplace,
+ "query_layernorm_list": reassign_query_key_layernorm_inplace,
+ "key_layernorm_list": reassign_query_key_layernorm_inplace,
+ "adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
+ "embed_tokens": remove_keys_inplace,
+ "freqs_sin": remove_keys_inplace,
+ "freqs_cos": remove_keys_inplace,
+ "position_embedding": remove_keys_inplace,
+}
+
+VAE_KEYS_RENAME_DICT = {
+ "block.": "resnets.",
+ "down.": "down_blocks.",
+ "downsample": "downsamplers.0",
+ "upsample": "upsamplers.0",
+ "nin_shortcut": "conv_shortcut",
+ "encoder.mid.block_1": "encoder.mid_block.resnets.0",
+ "encoder.mid.block_2": "encoder.mid_block.resnets.1",
+ "decoder.mid.block_1": "decoder.mid_block.resnets.0",
+ "decoder.mid.block_2": "decoder.mid_block.resnets.1",
+}
+
+VAE_SPECIAL_KEYS_REMAP = {
+ "loss": remove_keys_inplace,
+ "up.": replace_up_keys_inplace,
+}
+
+TOKENIZER_MAX_LENGTH = 226
+
+
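+# Unwrap checkpoints nested as {"model": ...}, {"module": ...}, or
+# {"state_dict": ...} (as commonly produced by training wrappers); flat state
+# dicts pass through unchanged.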
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+    state_dict = saved_dict
+    if "model" in state_dict:
+        state_dict = state_dict["model"]
+    if "module" in state_dict:
+        state_dict = state_dict["module"]
+    if "state_dict" in state_dict:
+        state_dict = state_dict["state_dict"]
+    return state_dict
+
+
+def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def convert_transformer(
+ ckpt_path: str,
+ num_layers: int,
+ num_attention_heads: int,
+ use_rotary_positional_embeddings: bool,
+ i2v: bool,
+ dtype: torch.dtype,
+):
+ PREFIX_KEY = "model.diffusion_model."
+
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+ transformer = CogVideoXTransformer3DModel(
+ in_channels=32 if i2v else 16,
+ num_layers=num_layers,
+ num_attention_heads=num_attention_heads,
+ use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=i2v,
+ ).to(dtype=dtype)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[len(PREFIX_KEY) :]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+ transformer.load_state_dict(original_state_dict, strict=True)
+ return transformer
+
+
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
+ vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+
+ for key in list(original_state_dict.keys()):
+        new_key = key
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_inplace(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True)
+ return vae
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+ )
+ parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+ parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+ parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
+ parser.add_argument(
+ "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
+ )
+ parser.add_argument(
+ "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
+ )
+ # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+ parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+ # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+ parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+ # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+ parser.add_argument(
+ "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+ )
+ # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+ parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+ # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+ parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
+ parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+ return parser.parse_args()
+
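+# Example invocation for CogVideoX-5B using the values noted in get_args above
+# (checkpoint paths are placeholders):
+#
+#   python convert_cogvideox_to_diffusers.py \
+#       --transformer_ckpt_path /path/to/transformer.pt \
+#       --vae_ckpt_path /path/to/vae.pt \
+#       --output_path ./cogvideox-5b \
+#       --num_layers 42 --num_attention_heads 48 \
+#       --use_rotary_positional_embeddings \
+#       --scaling_factor 0.7 --snr_shift_scale 1.0 --bf16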
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ vae = None
+
+ if args.fp16 and args.bf16:
+ raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
+
+ dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
+ if args.transformer_ckpt_path is not None:
+ transformer = convert_transformer(
+ args.transformer_ckpt_path,
+ args.num_layers,
+ args.num_attention_heads,
+ args.use_rotary_positional_embeddings,
+ args.i2v,
+ dtype,
+ )
+ if args.vae_ckpt_path is not None:
+ vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+
+ text_encoder_id = "google/t5-v1_1-xxl"
+ tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
+ text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+
+    # Make the text encoder weights contiguous before saving; safetensors
+    # serialization fails on non-contiguous tensors, so the conversion breaks without this.
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
+ scheduler = CogVideoXDDIMScheduler.from_config(
+ {
+ "snr_shift_scale": args.snr_shift_scale,
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": False,
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "rescale_betas_zero_snr": True,
+ "set_alpha_to_one": True,
+ "timestep_spacing": "trailing",
+ }
+ )
+ if args.i2v:
+ pipeline_cls = CogVideoXImageToVideoPipeline
+ else:
+ pipeline_cls = CogVideoXPipeline
+
+ pipe = pipeline_cls(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ vae=vae,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ if args.fp16:
+ pipe = pipe.to(dtype=torch.float16)
+ if args.bf16:
+ pipe = pipe.to(dtype=torch.bfloat16)
+
+    # We don't use a `variant` here because the model must be run in fp16 (2B) or bf16 (5B). It would be
+    # odd for users to have to pass a variant to get the correct default when that default is not fp32 anyway.
+
+    # Sharding the checkpoint (max_shard_size="5GB") matters for users with limited memory, such as those
+    # on Colab and in notebooks, as it reduces the memory needed while loading the model.
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
diff --git a/diffusers/scripts/convert_consistency_decoder.py b/diffusers/scripts/convert_consistency_decoder.py
new file mode 100755
index 0000000..0cb5fc5
--- /dev/null
+++ b/diffusers/scripts/convert_consistency_decoder.py
@@ -0,0 +1,1128 @@
+import math
+import os
+import urllib
+import warnings
+from argparse import ArgumentParser
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from huggingface_hub.utils import insecure_hashlib
+from safetensors.torch import load_file as stl
+from tqdm import tqdm
+
+from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
+from diffusers.models.autoencoders.vae import Encoder
+from diffusers.models.embeddings import TimestepEmbedding
+from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
+
+
+parser = ArgumentParser()
+parser.add_argument("--save_pretrained", required=False, default=None, type=str)
+parser.add_argument("--test_image", required=True, type=str)
+args = parser.parse_args()
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
+ res = arr[timesteps].float()
+ dims_to_append = len(broadcast_shape) - len(res.shape)
+ return res[(...,) + (None,) * dims_to_append]
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas)
+
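+# Each discrete beta follows from the cumulative-product identity
+# alpha_bar(t2) / alpha_bar(t1) = 1 - beta_t (capped at max_beta), so the
+# cosine alpha_bar passed in by ConsistencyDecoder.__init__ below reproduces
+# its intended noise schedule exactly.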
+
+def _download(url: str, root: str):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(
+ total=int(source.info().get("Content-Length")),
+ ncols=80,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+class ConsistencyDecoder:
+ def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
+ self.n_distilled_steps = 64
+ download_target = _download(
+ "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
+ download_root,
+ )
+ self.ckpt = torch.jit.load(download_target).to(device)
+ self.device = device
+ sigma_data = 0.5
+ betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+ sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+ sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
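+        # The three coefficients below are the consistency-model preconditioning
+        # terms written in the discrete DDPM parameterization; c_skip -> 1 and
+        # c_out -> 0 as sigma -> 0, the boundary condition that makes the model
+        # an identity map at t = 0.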
+ self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
+ self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
+ self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
+
+ @staticmethod
+ def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
+ with torch.no_grad():
+ space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
+ rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
+            # Both branches clamp the top bucket down; only truncate_start=False
+            # bumps a resulting t == 0 back up by one bucket.
+            rounded_timesteps[rounded_timesteps == total_timesteps] -= space
+            if not truncate_start:
+                rounded_timesteps[rounded_timesteps == 0] += space
+ return rounded_timesteps
+
+ @staticmethod
+ def ldm_transform_latent(z, extra_scale_factor=1):
+ channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
+ channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]
+
+ if len(z.shape) != 4:
+            raise ValueError(f"Expected a 4D latent tensor (B, C, H, W), got shape {tuple(z.shape)}.")
+
+ z = z * 0.18215
+ channels = [z[:, i] for i in range(z.shape[1])]
+
+ channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
+ return torch.stack(channels, dim=1)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ features: torch.Tensor,
+        schedule=(1.0, 0.5),
+ generator=None,
+ ):
+ features = self.ldm_transform_latent(features)
+ ts = self.round_timesteps(
+ torch.arange(0, 1024),
+ 1024,
+ self.n_distilled_steps,
+ truncate_start=False,
+ )
+ shape = (
+ features.size(0),
+ 3,
+ 8 * features.size(2),
+ 8 * features.size(3),
+ )
+ x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
+ schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
+ for i in schedule_timesteps:
+ t = ts[i].item()
+ t_ = torch.tensor([t] * features.shape[0]).to(self.device)
+ # noise = torch.randn_like(x_start)
+ noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
+ x_start = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
+ )
+ c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)
+
+ if isinstance(self.ckpt, UNet2DModel):
+                model_input = torch.concat([c_in * x_start, F.interpolate(features, scale_factor=8, mode="nearest")], dim=1)
+                model_output = self.ckpt(model_input, t_).sample
+ else:
+ model_output = self.ckpt(c_in * x_start, t_, features=features)
+
+ B, C = x_start.shape[:2]
+ model_output, _ = torch.split(model_output, C, dim=1)
+ pred_xstart = (
+ _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
+ + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
+ ).clamp(-1, 1)
+ x_start = pred_xstart
+ return x_start
+
+
+def save_image(image, name):
+ import numpy as np
+ from PIL import Image
+
+ image = image[0].cpu().numpy()
+ image = (image + 1.0) * 127.5
+ image = image.clip(0, 255).astype(np.uint8)
+ image = Image.fromarray(image.transpose(1, 2, 0))
+ image.save(name)
+
+
+def load_image(uri, size=None, center_crop=False):
+ import numpy as np
+ from PIL import Image
+
+ image = Image.open(uri)
+ if center_crop:
+ image = image.crop(
+ (
+ (image.width - min(image.width, image.height)) // 2,
+ (image.height - min(image.width, image.height)) // 2,
+ (image.width + min(image.width, image.height)) // 2,
+ (image.height + min(image.width, image.height)) // 2,
+ )
+ )
+ if size is not None:
+ image = image.resize(size)
+ image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float()
+ image = image / 127.5 - 1.0
+ return image
+
+
+class TimestepEmbedding_(nn.Module):
+ def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
+ super().__init__()
+ self.emb = nn.Embedding(n_time, n_emb)
+ self.f_1 = nn.Linear(n_emb, n_out)
+ self.f_2 = nn.Linear(n_out, n_out)
+
+ def forward(self, x) -> torch.Tensor:
+ x = self.emb(x)
+ x = self.f_1(x)
+ x = F.silu(x)
+ return self.f_2(x)
+
+
+class ImageEmbedding(nn.Module):
+ def __init__(self, in_channels=7, out_channels=320) -> None:
+ super().__init__()
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(x)
+
+
+class ImageUnembedding(nn.Module):
+ def __init__(self, in_channels=320, out_channels=6) -> None:
+ super().__init__()
+ self.gn = nn.GroupNorm(32, in_channels)
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(F.silu(self.gn(x)))
+
+
+class ConvResblock(nn.Module):
+ def __init__(self, in_features=320, out_features=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, out_features * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_features)
+ self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
+
+ self.gn_2 = nn.GroupNorm(32, out_features)
+ self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
+
+ skip_conv = in_features != out_features
+ self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity()
+
+ def forward(self, x, t):
+ x_skip = x
+ t = self.f_t(F.silu(t))
+ t = t.chunk(2, dim=1)
+ t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
+ t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ f_1 = self.f_1(gn_1)
+
+ gn_2 = self.gn_2(f_1)
+
+ return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
+
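+# Note: the forward pass above computes gn_2 * (1 + scale) + shift (t_1 carries
+# the "+ 1"), i.e. "scale_shift" time conditioning. The diffusers blocks built
+# during the conversion below are instantiated with
+# resnet_time_scale_shift="scale_shift" to match, which is what makes the
+# f_t -> time_emb_proj weight mapping valid.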
+
+# Also a ConvResblock, but with average-pool downsampling fused in
+class Downsample(nn.Module):
+ def __init__(self, in_channels=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
+
+ f_1 = self.f_1(avg_pool2d)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+ return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
+
+
+# Also a ConvResblock, but with nearest-neighbor upsampling fused in
+class Upsample(nn.Module):
+ def __init__(self, in_channels=1024) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+        upsample = F.interpolate(gn_1, scale_factor=2, mode="nearest")
+ f_1 = self.f_1(upsample)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+        return f_2 + F.interpolate(x_skip, scale_factor=2, mode="nearest")
+
+
+class ConvUNetVAE(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.embed_image = ImageEmbedding()
+ self.embed_time = TimestepEmbedding_()
+
+ down_0 = nn.ModuleList(
+ [
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ Downsample(320),
+ ]
+ )
+ down_1 = nn.ModuleList(
+ [
+ ConvResblock(320, 640),
+ ConvResblock(640, 640),
+ ConvResblock(640, 640),
+ Downsample(640),
+ ]
+ )
+ down_2 = nn.ModuleList(
+ [
+ ConvResblock(640, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ Downsample(1024),
+ ]
+ )
+ down_3 = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+ self.down = nn.ModuleList(
+ [
+ down_0,
+ down_1,
+ down_2,
+ down_3,
+ ]
+ )
+
+ self.mid = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+
+ up_3 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_2 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 + 640, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_1 = nn.ModuleList(
+ [
+ ConvResblock(1024 + 640, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(320 + 640, 640),
+ Upsample(640),
+ ]
+ )
+ up_0 = nn.ModuleList(
+ [
+ ConvResblock(320 + 640, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ]
+ )
+ self.up = nn.ModuleList(
+ [
+ up_0,
+ up_1,
+ up_2,
+ up_3,
+ ]
+ )
+
+ self.output = ImageUnembedding()
+
+ def forward(self, x, t, features) -> torch.Tensor:
+ converted = hasattr(self, "converted") and self.converted
+
+        x = torch.cat([x, F.interpolate(features, scale_factor=8, mode="nearest")], dim=1)
+
+ if converted:
+ t = self.time_embedding(self.time_proj(t))
+ else:
+ t = self.embed_time(t)
+
+ x = self.embed_image(x)
+
+ skips = [x]
+ for i, down in enumerate(self.down):
+ if converted and i in [0, 1, 2, 3]:
+ x, skips_ = down(x, t)
+ for skip in skips_:
+ skips.append(skip)
+ else:
+ for block in down:
+ x = block(x, t)
+ skips.append(x)
+ print(x.float().abs().sum())
+
+ if converted:
+ x = self.mid(x, t)
+ else:
+ for i in range(2):
+ x = self.mid[i](x, t)
+ print(x.float().abs().sum())
+
+ for i, up in enumerate(self.up[::-1]):
+ if converted and i in [0, 1, 2, 3]:
+ skip_4 = skips.pop()
+ skip_3 = skips.pop()
+ skip_2 = skips.pop()
+ skip_1 = skips.pop()
+ skips_ = (skip_1, skip_2, skip_3, skip_4)
+ x = up(x, skips_, t)
+ else:
+ for block in up:
+ if isinstance(block, ConvResblock):
+ x = torch.concat([x, skips.pop()], dim=1)
+ x = block(x, t)
+
+ return self.output(x)
+
+
+def rename_state_dict_key(k):
+ k = k.replace("blocks.", "")
+ for i in range(5):
+ k = k.replace(f"down_{i}_", f"down.{i}.")
+ k = k.replace(f"conv_{i}.", f"{i}.")
+ k = k.replace(f"up_{i}_", f"up.{i}.")
+ k = k.replace(f"mid_{i}", f"mid.{i}")
+ k = k.replace("upsamp.", "4.")
+ k = k.replace("downsamp.", "3.")
+ k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias")
+ k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias")
+ k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias")
+ k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias")
+ k = k.replace("f.w", "f.weight").replace("f.b", "f.bias")
+ k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias")
+ k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias")
+ k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias")
+ return k
+
+
+def rename_state_dict(sd, embedding):
+ sd = {rename_state_dict_key(k): v for k, v in sd.items()}
+ sd["embed_time.emb.weight"] = embedding["weight"]
+ return sd
+
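+# Example (hypothetical key): "blocks.down_0_conv_1.f_1.w" becomes
+# "down.0.1.f_1.weight" after the substitutions above.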
+
+# encode with stable diffusion vae
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe.vae.cuda()
+
+# construct original decoder with jitted model
+decoder_consistency = ConsistencyDecoder(device="cuda:0")
+
+# construct UNet code, overwrite the decoder with conv_unet_vae
+model = ConvUNetVAE()
+model.load_state_dict(
+ rename_state_dict(
+ stl("consistency_decoder.safetensors"),
+ stl("embedding.safetensors"),
+ )
+)
+model = model.cuda()
+
+decoder_consistency.ckpt = model
+
+image = load_image(args.test_image, size=(256, 256), center_crop=True)
+latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()
+
+# decode with gan
+sample_gan = pipe.vae.decode(latent).sample.detach()
+save_image(sample_gan, "gan.png")
+
+# decode with conv_unet_vae
+sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_orig, "con_orig.png")
+
+
+########### conversion
+
+print("CONVERSION")
+
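+# Parameter-name mapping applied throughout the block conversions below
+# (original ConvUNetVAE name -> diffusers name):
+#   gn_1 -> norm1, f_1 -> conv1, f_t -> time_emb_proj,
+#   gn_2 -> norm2, f_2 -> conv2, f_s -> conv_shortcut
+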
+print("DOWN BLOCK ONE")
+
+block_one_sd_orig = model.down[0].state_dict()
+block_one_sd_new = {}
+
+for i in range(3):
+ block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight")
+ block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias")
+ block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight")
+ block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight")
+block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias")
+block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight")
+block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias")
+block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight")
+block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias")
+block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight")
+block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias")
+block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight")
+block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias")
+
+assert len(block_one_sd_orig) == 0
+
+block_one = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_one.load_state_dict(block_one_sd_new)
+
+print("DOWN BLOCK TWO")
+
+block_two_sd_orig = model.down[1].state_dict()
+block_two_sd_new = {}
+
+for i in range(3):
+ block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight")
+ block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias")
+ block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight")
+ block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight")
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight")
+block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias")
+block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight")
+block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias")
+block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight")
+block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias")
+block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight")
+block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias")
+block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight")
+block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias")
+
+assert len(block_two_sd_orig) == 0
+
+block_two = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_two.load_state_dict(block_two_sd_new)
+
+print("DOWN BLOCK THREE")
+
+block_three_sd_orig = model.down[2].state_dict()
+block_three_sd_new = {}
+
+for i in range(3):
+ block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight")
+ block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias")
+ block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight")
+ block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight")
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight")
+block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias")
+block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight")
+block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias")
+block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight")
+block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias")
+block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight")
+block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias")
+block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight")
+block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias")
+
+assert len(block_three_sd_orig) == 0
+
+block_three = ResnetDownsampleBlock2D(
+ in_channels=640,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_three.load_state_dict(block_three_sd_new)
+
+print("DOWN BLOCK FOUR")
+
+block_four_sd_orig = model.down[3].state_dict()
+block_four_sd_new = {}
+
+for i in range(3):
+ block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight")
+ block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias")
+ block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight")
+ block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(block_four_sd_orig) == 0
+
+block_four = ResnetDownsampleBlock2D(
+ in_channels=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_four.load_state_dict(block_four_sd_new)
+
+
+print("MID BLOCK 1")
+
+mid_block_one_sd_orig = model.mid.state_dict()
+mid_block_one_sd_new = {}
+
+for i in range(2):
+ mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(mid_block_one_sd_orig) == 0
+
+mid_block_one = UNetMidBlock2D(
+ in_channels=1024,
+ temb_channels=1280,
+ num_layers=1,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+ add_attention=False,
+)
+
+mid_block_one.load_state_dict(mid_block_one_sd_new)
+
+print("UP BLOCK ONE")
+
+up_block_one_sd_orig = model.up[-1].state_dict()
+up_block_one_sd_new = {}
+
+for i in range(4):
+ up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight")
+up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias")
+up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight")
+up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias")
+up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight")
+up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias")
+up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight")
+up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_one_sd_orig) == 0
+
+up_block_one = ResnetUpsampleBlock2D(
+ in_channels=1024,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_one.load_state_dict(up_block_one_sd_new)
+
+print("UP BLOCK TWO")
+
+up_block_two_sd_orig = model.up[-2].state_dict()
+up_block_two_sd_new = {}
+
+for i in range(4):
+ up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight")
+up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias")
+up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight")
+up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias")
+up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight")
+up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias")
+up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight")
+up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_two_sd_orig) == 0
+
+up_block_two = ResnetUpsampleBlock2D(
+ in_channels=640,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_two.load_state_dict(up_block_two_sd_new)
+
+print("UP BLOCK THREE")
+
+up_block_three_sd_orig = model.up[-3].state_dict()
+up_block_three_sd_new = {}
+
+for i in range(4):
+ up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight")
+up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias")
+up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight")
+up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias")
+up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight")
+up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias")
+up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight")
+up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_three_sd_orig) == 0
+
+up_block_three = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=1024,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_three.load_state_dict(up_block_three_sd_new)
+
+print("UP BLOCK FOUR")
+
+up_block_four_sd_orig = model.up[-4].state_dict()
+up_block_four_sd_new = {}
+
+for i in range(4):
+ up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias")
+
+assert len(up_block_four_sd_orig) == 0
+
+up_block_four = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=640,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_four.load_state_dict(up_block_four_sd_new)
+
+print("initial projection (conv_in)")
+
+conv_in_sd_orig = model.embed_image.state_dict()
+conv_in_sd_new = {}
+
+conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
+conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")
+
+assert len(conv_in_sd_orig) == 0
+
+block_out_channels = [320, 640, 1024, 1024]
+
+in_channels = 7
+conv_in_kernel = 3
+conv_in_padding = (conv_in_kernel - 1) // 2
+conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
+
+conv_in.load_state_dict(conv_in_sd_new)
+
+print("out projection (conv_out) (conv_norm_out)")
+out_channels = 6
+norm_num_groups = 32
+norm_eps = 1e-5
+act_fn = "silu"
+conv_out_kernel = 3
+conv_out_padding = (conv_out_kernel - 1) // 2
+conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+# the original model applies the activation via torch.nn.functional, so there is no module to convert here
+# conv_act = get_activation(act_fn)
+conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)
+
+conv_norm_out.load_state_dict(model.output.gn.state_dict())
+conv_out.load_state_dict(model.output.f.state_dict())
+
+print("timestep projection (time_proj) (time_embedding)")
+
+f1_sd = model.embed_time.f_1.state_dict()
+f2_sd = model.embed_time.f_2.state_dict()
+
+time_embedding_sd = {
+ "linear_1.weight": f1_sd.pop("weight"),
+ "linear_1.bias": f1_sd.pop("bias"),
+ "linear_2.weight": f2_sd.pop("weight"),
+ "linear_2.bias": f2_sd.pop("bias"),
+}
+
+assert len(f1_sd) == 0
+assert len(f2_sd) == 0
+
+time_embedding_type = "learned"
+num_train_timesteps = 1024
+time_embedding_dim = 1280
+
+time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
+timestep_input_dim = block_out_channels[0]
+
+time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)
+
+time_proj.load_state_dict(model.embed_time.emb.state_dict())
+time_embedding.load_state_dict(time_embedding_sd)
+
+print("CONVERT")
+
+time_embedding.to("cuda")
+time_proj.to("cuda")
+conv_in.to("cuda")
+
+block_one.to("cuda")
+block_two.to("cuda")
+block_three.to("cuda")
+block_four.to("cuda")
+
+mid_block_one.to("cuda")
+
+up_block_one.to("cuda")
+up_block_two.to("cuda")
+up_block_three.to("cuda")
+up_block_four.to("cuda")
+
+conv_norm_out.to("cuda")
+conv_out.to("cuda")
+
+model.time_proj = time_proj
+model.time_embedding = time_embedding
+model.embed_image = conv_in
+
+model.down[0] = block_one
+model.down[1] = block_two
+model.down[2] = block_three
+model.down[3] = block_four
+
+model.mid = mid_block_one
+
+model.up[-1] = up_block_one
+model.up[-2] = up_block_two
+model.up[-3] = up_block_three
+model.up[-4] = up_block_four
+
+model.output.gn = conv_norm_out
+model.output.f = conv_out
+
+model.converted = True
+
+sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new, "con_new.png")
+
+assert (sample_consistency_orig == sample_consistency_new).all()
+
+print("making unet")
+
+unet = UNet2DModel(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=(
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ up_block_types=(
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ block_out_channels=block_out_channels,
+ layers_per_block=3,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ resnet_time_scale_shift="scale_shift",
+ time_embedding_type="learned",
+ num_train_timesteps=num_train_timesteps,
+ add_attention=False,
+)
+
+unet_state_dict = {}
+
+
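+# Reassemble the individually converted modules into a single state dict keyed by the
+# diffusers UNet2DModel submodule names.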
+def add_state_dict(prefix, mod):
+ for k, v in mod.state_dict().items():
+ unet_state_dict[f"{prefix}.{k}"] = v
+
+
+add_state_dict("conv_in", conv_in)
+add_state_dict("time_proj", time_proj)
+add_state_dict("time_embedding", time_embedding)
+add_state_dict("down_blocks.0", block_one)
+add_state_dict("down_blocks.1", block_two)
+add_state_dict("down_blocks.2", block_three)
+add_state_dict("down_blocks.3", block_four)
+add_state_dict("mid_block", mid_block_one)
+add_state_dict("up_blocks.0", up_block_one)
+add_state_dict("up_blocks.1", up_block_two)
+add_state_dict("up_blocks.2", up_block_three)
+add_state_dict("up_blocks.3", up_block_four)
+add_state_dict("conv_norm_out", conv_norm_out)
+add_state_dict("conv_out", conv_out)
+
+unet.load_state_dict(unet_state_dict)
+
+print("running with diffusers unet")
+
+unet.to("cuda")
+
+decoder_consistency.ckpt = unet
+
+sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new_2, "con_new_2.png")
+
+assert (sample_consistency_orig == sample_consistency_new_2).all()
+
+print("running with diffusers model")
+
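+# Monkey-patch Encoder.__init__ to record the kwargs it was constructed with, so an
+# identical encoder can be instantiated inside ConsistencyDecoderVAE below.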
+Encoder.old_constructor = Encoder.__init__
+
+
+def new_constructor(self, **kwargs):
+ self.old_constructor(**kwargs)
+ self.constructor_arguments = kwargs
+
+
+Encoder.__init__ = new_constructor
+
+
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+consistency_vae = ConsistencyDecoderVAE(
+ encoder_args=vae.encoder.constructor_arguments,
+ decoder_args=unet.config,
+ scaling_factor=vae.config.scaling_factor,
+ block_out_channels=vae.config.block_out_channels,
+ latent_channels=vae.config.latent_channels,
+)
+consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
+consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
+consistency_vae.decoder_unet.load_state_dict(unet.state_dict())
+
+consistency_vae.to(dtype=torch.float16, device="cuda")
+
+sample_consistency_new_3 = consistency_vae.decode(
+ 0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
+).sample
+
+print("max difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().max())
+print("total difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
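+# exact equality is not expected after the fp16 cast above, so only the differences are reported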
+# assert (sample_consistency_orig == sample_consistency_new_3).all()
+
+print("running with diffusers pipeline")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
+)
+pipe.to("cuda")
+
+pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")
+
+
+if args.save_pretrained is not None:
+ consistency_vae.save_pretrained(args.save_pretrained)
diff --git a/diffusers/scripts/convert_consistency_to_diffusers.py b/diffusers/scripts/convert_consistency_to_diffusers.py
new file mode 100755
index 0000000..0f8b4dd
--- /dev/null
+++ b/diffusers/scripts/convert_consistency_to_diffusers.py
@@ -0,0 +1,315 @@
+import argparse
+import os
+
+import torch
+
+from diffusers import (
+ CMStochasticIterativeScheduler,
+ ConsistencyModelPipeline,
+ UNet2DModel,
+)
+
+
+TEST_UNET_CONFIG = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 2,
+ "num_class_embeds": 1000,
+ "block_out_channels": [32, 64],
+ "attention_head_dim": 8,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "scale_shift",
+ "attn_norm_num_groups": 32,
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+IMAGENET_64_UNET_CONFIG = {
+ "sample_size": 64,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 3,
+ "num_class_embeds": 1000,
+ "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
+ "attention_head_dim": 64,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "scale_shift",
+ "attn_norm_num_groups": 32,
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
+LSUN_256_UNET_CONFIG = {
+ "sample_size": 256,
+ "in_channels": 3,
+ "out_channels": 3,
+ "layers_per_block": 2,
+ "num_class_embeds": None,
+ "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
+ "attention_head_dim": 64,
+ "down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ "AttnDownBlock2D",
+ ],
+ "up_block_types": [
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "AttnUpBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "resnet_time_scale_shift": "default",
+ "upsample_type": "resnet",
+ "downsample_type": "resnet",
+}
+
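+# "cd" checkpoints come from consistency distillation, "ct" checkpoints from consistency
+# training; the scheduler configs below differ only in num_train_timesteps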
+CD_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 40,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+CT_IMAGENET_64_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 201,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+CT_LSUN_256_SCHEDULER_CONFIG = {
+ "num_train_timesteps": 151,
+ "sigma_min": 0.002,
+ "sigma_max": 80.0,
+}
+
+
+def str2bool(v):
+ """
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+ """
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("boolean value expected")
+
+
+def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
+ new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
+ new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
+ new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
+ new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
+ new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
+ new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
+ new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
+ new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
+ new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
+ new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]
+
+ if has_skip:
+ new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
+ new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]
+
+ return new_checkpoint
+
+
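+# The original checkpoints store attention as a fused 1x1-conv qkv projection; split it into
+# thirds and squeeze the trailing spatial dims to obtain the separate linear q/k/v weights.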
+def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
+ weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
+ bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)
+
+ new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
+ new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]
+
+ new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
+ new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)
+
+ new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
+ checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
+ )
+ new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)
+
+ return new_checkpoint
+
+
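+# The original consistency-models UNet stores layers in flat sequential input_blocks /
+# middle_block / output_blocks lists; walk them in order (tracking `current_layer`) and map
+# each onto the nested diffusers down/mid/up block key names.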
+def con_pt_to_diffuser(checkpoint_path: str, unet_config):
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]
+
+ if unet_config["num_class_embeds"] is not None:
+ new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]
+
+ new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]
+
+ down_block_types = unet_config["down_block_types"]
+ layers_per_block = unet_config["layers_per_block"]
+ attention_head_dim = unet_config["attention_head_dim"]
+ channels_list = unet_config["block_out_channels"]
+ current_layer = 1
+ prev_channels = channels_list[0]
+
+ for i, layer_type in enumerate(down_block_types):
+ current_channels = channels_list[i]
+ downsample_block_has_skip = current_channels != prev_channels
+ if layer_type == "ResnetDownsampleBlock2D":
+ for j in range(layers_per_block):
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
+ old_prefix = f"input_blocks.{current_layer}.0"
+                has_skip = j == 0 and downsample_block_has_skip
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
+ current_layer += 1
+
+ elif layer_type == "AttnDownBlock2D":
+ for j in range(layers_per_block):
+ new_prefix = f"down_blocks.{i}.resnets.{j}"
+ old_prefix = f"input_blocks.{current_layer}.0"
+                has_skip = j == 0 and downsample_block_has_skip
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
+ new_prefix = f"down_blocks.{i}.attentions.{j}"
+ old_prefix = f"input_blocks.{current_layer}.1"
+ new_checkpoint = convert_attention(
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
+ )
+ current_layer += 1
+
+ if i != len(down_block_types) - 1:
+ new_prefix = f"down_blocks.{i}.downsamplers.0"
+ old_prefix = f"input_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ current_layer += 1
+
+ prev_channels = current_channels
+
+ # hardcoded the mid-block for now
+ new_prefix = "mid_block.resnets.0"
+ old_prefix = "middle_block.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ new_prefix = "mid_block.attentions.0"
+ old_prefix = "middle_block.1"
+ new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
+ new_prefix = "mid_block.resnets.1"
+ old_prefix = "middle_block.2"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+
+ current_layer = 0
+ up_block_types = unet_config["up_block_types"]
+
+ for i, layer_type in enumerate(up_block_types):
+ if layer_type == "ResnetUpsampleBlock2D":
+ for j in range(layers_per_block + 1):
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
+ old_prefix = f"output_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
+ current_layer += 1
+
+ if i != len(up_block_types) - 1:
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
+ old_prefix = f"output_blocks.{current_layer-1}.1"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+ elif layer_type == "AttnUpBlock2D":
+ for j in range(layers_per_block + 1):
+ new_prefix = f"up_blocks.{i}.resnets.{j}"
+ old_prefix = f"output_blocks.{current_layer}.0"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
+ new_prefix = f"up_blocks.{i}.attentions.{j}"
+ old_prefix = f"output_blocks.{current_layer}.1"
+ new_checkpoint = convert_attention(
+ checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
+ )
+ current_layer += 1
+
+ if i != len(up_block_types) - 1:
+ new_prefix = f"up_blocks.{i}.upsamplers.0"
+ old_prefix = f"output_blocks.{current_layer-1}.2"
+ new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
+
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]
+
+ return new_checkpoint
+
+
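+# Example invocation (checkpoint filename chosen to match the config lookup below; paths are hypothetical):
+#   python convert_consistency_to_diffusers.py --unet_path ./cd_imagenet64_l2.pt --dump_path ./cd-imagenet64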
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
+ parser.add_argument(
+ "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
+ )
+ parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
+
+ args = parser.parse_args()
+ args.class_cond = str2bool(args.class_cond)
+
+ ckpt_name = os.path.basename(args.unet_path)
+ print(f"Checkpoint: {ckpt_name}")
+
+ # Get U-Net config
+ if "imagenet64" in ckpt_name:
+ unet_config = IMAGENET_64_UNET_CONFIG
+ elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
+ unet_config = LSUN_256_UNET_CONFIG
+ elif "test" in ckpt_name:
+ unet_config = TEST_UNET_CONFIG
+ else:
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
+
+ if not args.class_cond:
+ unet_config["num_class_embeds"] = None
+
+ converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
+
+ image_unet = UNet2DModel(**unet_config)
+ image_unet.load_state_dict(converted_unet_ckpt)
+
+ # Get scheduler config
+ if "cd" in ckpt_name or "test" in ckpt_name:
+ scheduler_config = CD_SCHEDULER_CONFIG
+ elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
+ scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
+ elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
+ scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
+ else:
+ raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
+
+ cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)
+
+ consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
+ consistency_model.save_pretrained(args.dump_path)
diff --git a/diffusers/scripts/convert_dance_diffusion_to_diffusers.py b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py
new file mode 100755
index 0000000..ce69bfe
--- /dev/null
+++ b/diffusers/scripts/convert_dance_diffusion_to_diffusers.py
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+import argparse
+import math
+import os
+from copy import deepcopy
+
+import requests
+import torch
+from audio_diffusion.models import DiffusionAttnUnet1D
+from diffusion import sampling
+from torch import nn
+
+from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
+
+
+MODELS_MAP = {
+ "gwf-440k": {
+ "url": "https://model-server.zqevans2.workers.dev/gwf-440k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 65536,
+ },
+ "jmann-small-190k": {
+ "url": "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 65536,
+ },
+ "jmann-large-580k": {
+ "url": "https://model-server.zqevans2.workers.dev/jmann-large-580k.ckpt",
+ "sample_rate": 48000,
+ "sample_size": 131072,
+ },
+ "maestro-uncond-150k": {
+ "url": "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+ "unlocked-uncond-250k": {
+ "url": "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+ "honk-140k": {
+ "url": "https://model-server.zqevans2.workers.dev/honk-140k.ckpt",
+ "sample_rate": 16000,
+ "sample_size": 65536,
+ },
+}
+
+
+def alpha_sigma_to_t(alpha, sigma):
+ """Returns a timestep, given the scaling factors for the clean image and for
+ the noise."""
+ return torch.atan2(sigma, alpha) / math.pi * 2
+
+
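+# The "crash" schedule remaps a linear t through sigma = sin(t*pi/2)**2 under the
+# alpha = cos(phi), sigma = sin(phi) parameterization inverted by alpha_sigma_to_t above.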
+def get_crash_schedule(t):
+ sigma = torch.sin(t * math.pi / 2) ** 2
+ alpha = (1 - sigma**2) ** 0.5
+ return alpha_sigma_to_t(alpha, sigma)
+
+
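+# bare attribute container, used below to mimic the original training config object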
+class Object(object):
+ pass
+
+
+class DiffusionUncond(nn.Module):
+ def __init__(self, global_args):
+ super().__init__()
+
+ self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
+ self.diffusion_ema = deepcopy(self.diffusion)
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
+
+
+def download(model_name):
+ url = MODELS_MAP[model_name]["url"]
+ r = requests.get(url, stream=True)
+
+ local_filename = f"./{model_name}.ckpt"
+ with open(local_filename, "wb") as fp:
+ for chunk in r.iter_content(chunk_size=8192):
+ fp.write(chunk)
+
+ return local_filename
+
+
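+# These maps translate the flat nn.Sequential positions at each UNet level into diffusers
+# resnet/attention names; position 7 is the nested inner level, which rename() below descends
+# into via its depth counter.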
+DOWN_NUM_TO_LAYER = {
+ "1": "resnets.0",
+ "2": "attentions.0",
+ "3": "resnets.1",
+ "4": "attentions.1",
+ "5": "resnets.2",
+ "6": "attentions.2",
+}
+UP_NUM_TO_LAYER = {
+ "8": "resnets.0",
+ "9": "attentions.0",
+ "10": "resnets.1",
+ "11": "attentions.1",
+ "12": "resnets.2",
+ "13": "attentions.2",
+}
+MID_NUM_TO_LAYER = {
+ "1": "resnets.0",
+ "2": "attentions.0",
+ "3": "resnets.1",
+ "4": "attentions.1",
+ "5": "resnets.2",
+ "6": "attentions.2",
+ "8": "resnets.3",
+ "9": "attentions.3",
+ "10": "resnets.4",
+ "11": "attentions.4",
+ "12": "resnets.5",
+ "13": "attentions.5",
+}
+DEPTH_0_TO_LAYER = {
+ "0": "resnets.0",
+ "1": "resnets.1",
+ "2": "resnets.2",
+ "4": "resnets.0",
+ "5": "resnets.1",
+ "6": "resnets.2",
+}
+
+RES_CONV_MAP = {
+ "skip": "conv_skip",
+ "main.0": "conv_1",
+ "main.1": "group_norm_1",
+ "main.3": "conv_2",
+ "main.4": "group_norm_2",
+}
+
+ATTN_MAP = {
+ "norm": "group_norm",
+ "qkv_proj": ["query", "key", "value"],
+ "out_proj": ["proj_attn"],
+}
+
+
+def convert_resconv_naming(name):
+ if name.startswith("skip"):
+ return name.replace("skip", RES_CONV_MAP["skip"])
+
+    # name has to be of the format main.{digit}
+ if not name.startswith("main."):
+ raise ValueError(f"ResConvBlock error with {name}")
+
+ return name.replace(name[:6], RES_CONV_MAP[name[:6]])
+
+
+def convert_attn_naming(name):
+ for key, value in ATTN_MAP.items():
+ if name.startswith(key) and not isinstance(value, list):
+ return name.replace(key, value)
+ elif name.startswith(key):
+ return [name.replace(key, v) for v in value]
+ raise ValueError(f"Attn error with {name}")
+
+
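+# The original model is a recursively nested Sequential: a leading "net.3." and every "main.7."
+# segment each descend one level. Track that depth, then map the layer index at the current
+# level onto the corresponding diffusers down/mid/up block prefix.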
+def rename(input_string, max_depth=13):
+ string = input_string
+
+ if string.split(".")[0] == "timestep_embed":
+ return string.replace("timestep_embed", "time_proj")
+
+ depth = 0
+ if string.startswith("net.3."):
+ depth += 1
+ string = string[6:]
+ elif string.startswith("net."):
+ string = string[4:]
+
+ while string.startswith("main.7."):
+ depth += 1
+ string = string[7:]
+
+ if string.startswith("main."):
+ string = string[5:]
+
+ # mid block
+ if string[:2].isdigit():
+ layer_num = string[:2]
+ string_left = string[2:]
+ else:
+ layer_num = string[0]
+ string_left = string[1:]
+
+ if depth == max_depth:
+ new_layer = MID_NUM_TO_LAYER[layer_num]
+ prefix = "mid_block"
+ elif depth > 0 and int(layer_num) < 7:
+ new_layer = DOWN_NUM_TO_LAYER[layer_num]
+ prefix = f"down_blocks.{depth}"
+ elif depth > 0 and int(layer_num) > 7:
+ new_layer = UP_NUM_TO_LAYER[layer_num]
+ prefix = f"up_blocks.{max_depth - depth - 1}"
+ elif depth == 0:
+ new_layer = DEPTH_0_TO_LAYER[layer_num]
+ prefix = f"up_blocks.{max_depth - 1}" if int(layer_num) > 3 else "down_blocks.0"
+
+ if not string_left.startswith("."):
+ raise ValueError(f"Naming error with {input_string} and string_left: {string_left}.")
+
+ string_left = string_left[1:]
+
+ if "resnets" in new_layer:
+ string_left = convert_resconv_naming(string_left)
+ elif "attentions" in new_layer:
+ new_string_left = convert_attn_naming(string_left)
+ string_left = new_string_left
+
+ if not isinstance(string_left, list):
+ new_string = prefix + "." + new_layer + "." + string_left
+ else:
+ new_string = [prefix + "." + new_layer + "." + s for s in string_left]
+ return new_string
+
+
+def rename_orig_weights(state_dict):
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.endswith("kernel"):
+            # up- and downsample layers don't have trainable weights
+ continue
+
+ new_k = rename(k)
+
+ # check if we need to transform from Conv => Linear for attention
+ if isinstance(new_k, list):
+ new_state_dict = transform_conv_attns(new_state_dict, new_k, v)
+ else:
+ new_state_dict[new_k] = v
+
+ return new_state_dict
+
+
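+# The original attention projections are 1x1 convolutions while diffusers uses Linear layers:
+# squeeze the trailing kernel dim from 3-d weights, and split fused qkv tensors into thirds.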
+def transform_conv_attns(new_state_dict, new_k, v):
+ if len(new_k) == 1:
+ if len(v.shape) == 3:
+ # weight
+ new_state_dict[new_k[0]] = v[:, :, 0]
+ else:
+ # bias
+ new_state_dict[new_k[0]] = v
+ else:
+ # qkv matrices
+        tripled_shape = v.shape[0]
+        single_shape = tripled_shape // 3
+ for i in range(3):
+ if len(v.shape) == 3:
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape, :, 0]
+ else:
+ new_state_dict[new_k[i]] = v[i * single_shape : (i + 1) * single_shape]
+ return new_state_dict
+
+
+def main(args):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ model_name = args.model_path.split("/")[-1].split(".")[0]
+ if not os.path.isfile(args.model_path):
+ assert (
+ model_name == args.model_path
+ ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ args.model_path = download(model_name)
+
+ sample_rate = MODELS_MAP[model_name]["sample_rate"]
+ sample_size = MODELS_MAP[model_name]["sample_size"]
+
+ config = Object()
+ config.sample_size = sample_size
+ config.sample_rate = sample_rate
+ config.latent_dim = 0
+
+ diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
+ diffusers_state_dict = diffusers_model.state_dict()
+
+ orig_model = DiffusionUncond(config)
+ orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
+ orig_model = orig_model.diffusion_ema.eval()
+ orig_model_state_dict = orig_model.state_dict()
+ renamed_state_dict = rename_orig_weights(orig_model_state_dict)
+
+ renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
+ diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())
+
+ assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
+ assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
+
+ for key, value in renamed_state_dict.items():
+ assert (
+ diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
+ ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ if key == "time_proj.weight":
+ value = value.squeeze()
+
+ diffusers_state_dict[key] = value
+
+ diffusers_model.load_state_dict(diffusers_state_dict)
+
+ steps = 100
+ seed = 33
+
+ diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)
+
+ generator = torch.manual_seed(seed)
+ noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)
+
+ t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
+ step_list = get_crash_schedule(t)
+
+ pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)
+
+ generator = torch.manual_seed(33)
+ audio = pipe(num_inference_steps=steps, generator=generator).audios
+
+ generated = sampling.iplms_sample(orig_model, noise, step_list, {})
+ generated = generated.clamp(-1, 1)
+
+ diff_sum = (generated - audio).abs().sum()
+ diff_max = (generated - audio).abs().max()
+
+ if args.save:
+ pipe.save_pretrained(args.checkpoint_path)
+
+ print("Diff sum", diff_sum)
+ print("Diff max", diff_max)
+
+ assert diff_max < 1e-3, f"Diff max: {diff_max} is too much :-/"
+
+ print(f"Conversion for {model_name} successful!")
+
+
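+# Example invocation (model name must be a MODELS_MAP key so the checkpoint can be downloaded;
+# the output path is hypothetical):
+#   python convert_dance_diffusion_to_diffusers.py --model_path gwf-440k --checkpoint_path ./gwf-440k-diffusers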
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+    parser.add_argument(
+        "--save",
+        default=True,
+        # argparse's `type=bool` treats any non-empty string (including "False") as True, so parse explicitly
+        type=lambda v: str(v).lower() not in ("false", "f", "no", "n", "0"),
+        required=False,
+        help="Whether to save the converted model or not.",
+    )
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
new file mode 100755
index 0000000..4659578
--- /dev/null
+++ b/diffusers/scripts/convert_ddpm_original_checkpoint_to_diffusers.py
@@ -0,0 +1,431 @@
+import argparse
+import json
+
+import torch
+
+from diffusers import AutoencoderKL, DDPMPipeline, DDPMScheduler, UNet2DModel, VQModel
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
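+
+    Example: shave_segments("down.0.block.1.conv1.weight", 2) -> "block.1.conv1.weight"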
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+ new_item = new_item.replace("block.", "resnets.")
+ new_item = new_item.replace("conv_shorcut", "conv1")
+ new_item = new_item.replace("in_shortcut", "conv_shortcut")
+ new_item = new_item.replace("temb_proj", "time_emb_proj")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0, in_mid=False):
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # In `model.mid`, the layer is called `attn`.
+ if not in_mid:
+ new_item = new_item.replace("attn", "attentions")
+ new_item = new_item.replace(".k.", ".key.")
+ new_item = new_item.replace(".v.", ".value.")
+ new_item = new_item.replace(".q.", ".query.")
+
+ new_item = new_item.replace("proj_out", "proj_attn")
+ new_item = new_item.replace("norm", "group_norm")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
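+# Copies tensors from the old checkpoint into the new one according to `paths`. When
+# `attention_paths_to_split` is given, fused qkv tensors are reshaped per attention head and
+# split into separate query/key/value entries first.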
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ if attention_paths_to_split is not None:
+ if config is None:
+ raise ValueError("Please specify the config if setting 'attention_paths_to_split' to 'True'.")
+
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config.get("num_head_channels", 1) // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape).squeeze()
+ checkpoint[path_map["key"]] = key.reshape(target_shape).squeeze()
+ checkpoint[path_map["value"]] = value.reshape(target_shape).squeeze()
+
+ for path in paths:
+ new_path = path["new"]
+
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ new_path = new_path.replace("down.", "down_blocks.")
+ new_path = new_path.replace("up.", "up_blocks.")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ if "attentions" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]].squeeze()
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def convert_ddpm_checkpoint(checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["temb.dense.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["temb.dense.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["temb.dense.1.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["temb.dense.1.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = checkpoint["norm_out.weight"]
+ new_checkpoint["conv_norm_out.bias"] = checkpoint["norm_out.bias"]
+
+ new_checkpoint["conv_in.weight"] = checkpoint["conv_in.weight"]
+ new_checkpoint["conv_in.bias"] = checkpoint["conv_in.bias"]
+ new_checkpoint["conv_out.weight"] = checkpoint["conv_out.weight"]
+ new_checkpoint["conv_out.bias"] = checkpoint["conv_out.bias"]
+
+ num_down_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "down" in layer})
+ down_blocks = {
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ num_up_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "up" in layer})
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+
+ for i in range(num_down_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+
+ if any("downsample" in layer for layer in down_blocks[i]):
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+ f"down.{i}.downsample.op.weight"
+ ]
+ new_checkpoint[f"down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[f"down.{i}.downsample.op.bias"]
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.weight'] = checkpoint[f'down.{i}.downsample.conv.weight']
+ # new_checkpoint[f'down_blocks.{i}.downsamplers.0.op.bias'] = checkpoint[f'down.{i}.downsample.conv.bias']
+
+ if any("block" in layer for layer in down_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+ if any("attn" in layer for layer in down_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in down_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+ # Mid new 2
+ paths = renew_resnet_paths(mid_block_1_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+ )
+
+ paths = renew_resnet_paths(mid_block_2_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+ )
+
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+ )
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ if any("upsample" in layer for layer in up_blocks[i]):
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+ f"up.{i}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[f"up.{i}.upsample.conv.bias"]
+
+ if any("block" in layer for layer in up_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ if any("attn" in layer for layer in up_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 2).split(".")[:2]) for layer in up_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
+ return new_checkpoint
+
+
+def convert_vq_autoenc_checkpoint(checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_norm_out.weight"] = checkpoint["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = checkpoint["encoder.norm_out.bias"]
+
+ new_checkpoint["encoder.conv_in.weight"] = checkpoint["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = checkpoint["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = checkpoint["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = checkpoint["encoder.conv_out.bias"]
+
+ new_checkpoint["decoder.conv_norm_out.weight"] = checkpoint["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = checkpoint["decoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = checkpoint["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = checkpoint["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = checkpoint["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = checkpoint["decoder.conv_out.bias"]
+
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "down" in layer})
+ down_blocks = {
+ layer_id: [key for key in checkpoint if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in checkpoint if "up" in layer})
+ up_blocks = {layer_id: [key for key in checkpoint if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
+
+ for i in range(num_down_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+
+ if any("downsample" in layer for layer in down_blocks[i]):
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = checkpoint[
+ f"encoder.down.{i}.downsample.conv.weight"
+ ]
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = checkpoint[
+ f"encoder.down.{i}.downsample.conv.bias"
+ ]
+
+ if any("block" in layer for layer in down_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in down_blocks[i] if f"block.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint)
+
+ if any("attn" in layer for layer in down_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in down_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in down_blocks[i] if f"attn.{layer_id}" in key]
+ for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"]):
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, config=config)
+
+ mid_block_1_layers = [key for key in checkpoint if "mid.block_1" in key]
+ mid_block_2_layers = [key for key in checkpoint if "mid.block_2" in key]
+ mid_attn_1_layers = [key for key in checkpoint if "mid.attn_1" in key]
+
+ # Mid new 2
+ paths = renew_resnet_paths(mid_block_1_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_1", "new": "resnets.0"}],
+ )
+
+ paths = renew_resnet_paths(mid_block_2_layers)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "block_2", "new": "resnets.1"}],
+ )
+
+ paths = renew_attention_paths(mid_attn_1_layers, in_mid=True)
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ checkpoint,
+ additional_replacements=[{"old": "mid.", "new": "mid_new_2."}, {"old": "attn_1", "new": "attentions.0"}],
+ )
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+
+ if any("upsample" in layer for layer in up_blocks[i]):
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.weight"] = checkpoint[
+ f"decoder.up.{i}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{block_id}.upsamplers.0.conv.bias"] = checkpoint[
+ f"decoder.up.{i}.upsample.conv.bias"
+ ]
+
+ if any("block" in layer for layer in up_blocks[i]):
+ num_blocks = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "block" in layer}
+ )
+ blocks = {
+ layer_id: [key for key in up_blocks[i] if f"block.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_blocks > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_resnet_paths(blocks[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ if any("attn" in layer for layer in up_blocks[i]):
+ num_attn = len(
+ {".".join(shave_segments(layer, 3).split(".")[:3]) for layer in up_blocks[i] if "attn" in layer}
+ )
+ attns = {
+ layer_id: [key for key in up_blocks[i] if f"attn.{layer_id}" in key] for layer_id in range(num_blocks)
+ }
+
+ if num_attn > 0:
+ for j in range(config["layers_per_block"] + 1):
+ replace_indices = {"old": f"up_blocks.{i}", "new": f"up_blocks.{block_id}"}
+ paths = renew_attention_paths(attns[j])
+ assign_to_checkpoint(paths, new_checkpoint, checkpoint, additional_replacements=[replace_indices])
+
+ new_checkpoint = {k.replace("mid_new_2", "mid_block"): v for k, v in new_checkpoint.items()}
+ new_checkpoint["quant_conv.weight"] = checkpoint["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = checkpoint["quant_conv.bias"]
+ if "quantize.embedding.weight" in checkpoint:
+ new_checkpoint["quantize.embedding.weight"] = checkpoint["quantize.embedding.weight"]
+ new_checkpoint["post_quant_conv.weight"] = checkpoint["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = checkpoint["post_quant_conv.bias"]
+
+ return new_checkpoint
+
+
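+# Example invocation (paths are hypothetical; the config json must describe the target class):
+#   python convert_ddpm_original_checkpoint_to_diffusers.py --checkpoint_path ./model.ckpt \
+#       --config_file ./config.json --dump_path ./converted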
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+
+ parser.add_argument(
+ "--config_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The config json file corresponding to the architecture.",
+ )
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+
+ args = parser.parse_args()
+ checkpoint = torch.load(args.checkpoint_path)
+
+ with open(args.config_file) as f:
+ config = json.loads(f.read())
+
+ # unet case
+ key_prefix_set = {key.split(".")[0] for key in checkpoint.keys()}
+ if "encoder" in key_prefix_set and "decoder" in key_prefix_set:
+ converted_checkpoint = convert_vq_autoenc_checkpoint(checkpoint, config)
+ else:
+ converted_checkpoint = convert_ddpm_checkpoint(checkpoint, config)
+
+ if "ddpm" in config:
+ del config["ddpm"]
+
+ if config["_class_name"] == "VQModel":
+ model = VQModel(**config)
+ model.load_state_dict(converted_checkpoint)
+ model.save_pretrained(args.dump_path)
+ elif config["_class_name"] == "AutoencoderKL":
+ model = AutoencoderKL(**config)
+ model.load_state_dict(converted_checkpoint)
+ model.save_pretrained(args.dump_path)
+ else:
+ model = UNet2DModel(**config)
+ model.load_state_dict(converted_checkpoint)
+
+ scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
+
+ pipe = DDPMPipeline(unet=model, scheduler=scheduler)
+ pipe.save_pretrained(args.dump_path)
diff --git a/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py b/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py
new file mode 100755
index 0000000..dfb3871
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_sdxl_lora_to_webui.py
@@ -0,0 +1,56 @@
+# Script for converting Hugging Face Diffusers-trained SDXL LoRAs to Kohya format.
+# This means that you can input your diffusers-trained LoRAs and
+# get the output to work with WebUIs such as AUTOMATIC1111, ComfyUI, SD.Next and others.
+
+# To get started, you can find some cool `diffusers`-trained LoRAs, such as this cute Corgy
+# https://huggingface.co/ignasbud/corgy_dog_LoRA/, download its `pytorch_lora_weights.safetensors` file
+# and run the script:
+# python convert_diffusers_sdxl_lora_to_webui.py --input_lora pytorch_lora_weights.safetensors --output_lora corgy.safetensors
+# Now you can use corgy.safetensors in your WebUI of choice!
+
+# To train your own, here are some diffusers training scripts and utils that you can use and then convert:
+# LoRA Ease - no code SDXL Dreambooth LoRA trainer: https://huggingface.co/spaces/multimodalart/lora-ease
+# Dreambooth Advanced Training Script - state of the art techniques such as pivotal tuning and prodigy optimizer:
+# - Script: https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+# - Colab (only on Pro): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb
+# Canonical diffusers training scripts:
+# - Script: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
+# - Colab (runs on free tier): https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb
+
+import argparse
+import os
+
+from safetensors.torch import load_file, save_file
+
+from diffusers.utils import convert_all_state_dict_to_peft, convert_state_dict_to_kohya
+
+
+def convert_and_save(input_lora, output_lora=None):
+ if output_lora is None:
+ base_name = os.path.splitext(input_lora)[0]
+ output_lora = f"{base_name}_webui.safetensors"
+
+ diffusers_state_dict = load_file(input_lora)
+ peft_state_dict = convert_all_state_dict_to_peft(diffusers_state_dict)
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+ save_file(kohya_state_dict, output_lora)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert LoRA model to PEFT and then to Kohya format.")
+ parser.add_argument(
+ "--input_lora",
+ type=str,
+ required=True,
+ help="Path to the input LoRA model file in the diffusers format.",
+ )
+ parser.add_argument(
+ "--output_lora",
+ type=str,
+ required=False,
+ help="Path for the converted LoRA (safetensors format for AUTOMATIC1111, ComfyUI, etc.). Optional, defaults to input name with a _webui suffix.",
+ )
+
+ args = parser.parse_args()
+
+ convert_and_save(args.input_lora, args.output_lora)
diff --git a/diffusers/scripts/convert_diffusers_to_original_sdxl.py b/diffusers/scripts/convert_diffusers_to_original_sdxl.py
new file mode 100755
index 0000000..648d037
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_to_original_sdxl.py
@@ -0,0 +1,350 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or anything else.
+
+import argparse
+import os.path as osp
+import re
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+ # the following are for sdxl
+ ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+ ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+ ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+ ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
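+# e.g. for i=1, j=0 the loop pairs sd "input_blocks.4.0." with hf "down_blocks.1.resnets.0." (3*i + j + 1 = 4)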
+for i in range(3):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i > 0:
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(4):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i < 2:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
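+# SDXL's up_blocks.0 contains attentions, so its upsampler actually sits at output_blocks.2.2
+# rather than the .1 produced above; patch that one mapping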
+unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {sd_name: unet_state_dict[hf_name] for hf_name, sd_name in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ # the following are for SDXL
+ ("q.", "to_q."),
+ ("k.", "to_k."),
+ ("v.", "to_v."),
+ ("proj_out.", "to_out.0."),
+]
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+    if w.ndim != 1:
+ return w.reshape(*w.shape, 1, 1)
+ else:
+ return w
+
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+
+
+textenc_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("transformer.resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "text_model.final_layer_norm."),
+ ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
+
+
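+# open_clip stores attention as fused in_proj_weight / in_proj_bias tensors, while the HF model
+# keeps separate q/k/v projections: collect the three HF tensors per layer and concatenate them
+# back into the fused form.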
+def convert_openclip_text_enc_state_dict(text_enc_dict):
+ new_state_dict = {}
+ capture_qkv_weight = {}
+ capture_qkv_bias = {}
+ for k, v in text_enc_dict.items():
+ if (
+ k.endswith(".self_attn.q_proj.weight")
+ or k.endswith(".self_attn.k_proj.weight")
+ or k.endswith(".self_attn.v_proj.weight")
+ ):
+ k_pre = k[: -len(".q_proj.weight")]
+ k_code = k[-len("q_proj.weight")]
+ if k_pre not in capture_qkv_weight:
+ capture_qkv_weight[k_pre] = [None, None, None]
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
+ continue
+
+ if (
+ k.endswith(".self_attn.q_proj.bias")
+ or k.endswith(".self_attn.k_proj.bias")
+ or k.endswith(".self_attn.v_proj.bias")
+ ):
+ k_pre = k[: -len(".q_proj.bias")]
+ k_code = k[-len("q_proj.bias")]
+ if k_pre not in capture_qkv_bias:
+ capture_qkv_bias[k_pre] = [None, None, None]
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
+ continue
+
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+ new_state_dict[relabelled_key] = v
+
+ for k_pre, tensors in capture_qkv_weight.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+ for k_pre, tensors in capture_qkv_bias.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+ return new_state_dict
+
+
+def convert_openai_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
+
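+# Example invocation (paths are hypothetical):
+#   python convert_diffusers_to_original_sdxl.py --model_path ./sdxl-pipeline \
+#       --checkpoint_path ./sdxl.safetensors --use_safetensors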
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+ parser.add_argument(
+        "--use_safetensors", action="store_true", help="Save the weights in safetensors format (default: .ckpt)."
+ )
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ # Path for safetensors
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "model.safetensors")
+
+    # Load each model from safetensors if present; otherwise fall back to the PyTorch .bin weights
+ if osp.exists(unet_path):
+ unet_state_dict = load_file(unet_path, device="cpu")
+ else:
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+ if osp.exists(vae_path):
+ vae_state_dict = load_file(vae_path, device="cpu")
+ else:
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+ if osp.exists(text_enc_path):
+ text_enc_dict = load_file(text_enc_path, device="cpu")
+ else:
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+ if osp.exists(text_enc_2_path):
+ text_enc_2_dict = load_file(text_enc_2_path, device="cpu")
+ else:
+ text_enc_2_path = osp.join(args.model_path, "text_encoder_2", "pytorch_model.bin")
+ text_enc_2_dict = torch.load(text_enc_2_path, map_location="cpu")
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+ # Convert text encoder 1
+ text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ # Convert text encoder 2
+ text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
+ text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
+ # We call the `.T.contiguous()` to match what's done in
+ # https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
+ text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
+ "conditioner.embedders.1.model.text_projection.weight"
+ ).T.contiguous()
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
+
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+
+ if args.use_safetensors:
+ save_file(state_dict, args.checkpoint_path)
+ else:
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
diff --git a/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py
new file mode 100755
index 0000000..d1b7df0
--- /dev/null
+++ b/diffusers/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -0,0 +1,353 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or any other thing.
+
+import argparse
+import os.path as osp
+import re
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
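+# Illustrative example (not executed): the loops above emit prefix pairs such as
+#   ("input_blocks.1.0.", "down_blocks.0.resnets.0.")
+#   ("input_blocks.3.0.op.", "down_blocks.0.downsamplers.0.conv.")
+#   ("middle_block.0.", "mid_block.resnets.0.")
+# which convert_unet_state_dict() below substring-replaces, HF -> SD.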
+
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "query."),
+ ("k.", "key."),
+ ("v.", "value."),
+ ("proj_out.", "proj_attn."),
+]
+
+# This is probably not the most ideal solution, but it does work.
+vae_extra_conversion_map = [
+ ("to_q", "q"),
+ ("to_k", "k"),
+ ("to_v", "v"),
+ ("to_out.0", "proj_out"),
+]
+
+
+def reshape_weight_for_sd(w):
+    # convert HF linear weights to SD conv2d weights
+    if w.ndim != 1:
+        return w.reshape(*w.shape, 1, 1)
+    else:
+        return w
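+# Minimal example (assumed shapes): an HF attention projection stored as a
+# Linear weight of shape (512, 512) comes back as a 1x1-conv weight of shape
+# (512, 512, 1, 1) for the SD VAE; 1-D bias tensors pass through unchanged.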
+
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ keys_to_rename = {}
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ for weight_name, real_weight_name in vae_extra_conversion_map:
+ if f"mid.attn_1.{weight_name}.weight" in k or f"mid.attn_1.{weight_name}.bias" in k:
+ keys_to_rename[k] = k.replace(weight_name, real_weight_name)
+ for k, v in keys_to_rename.items():
+ if k in new_state_dict:
+ print(f"Renaming {k} to {v}")
+ new_state_dict[v] = reshape_weight_for_sd(new_state_dict[k])
+ del new_state_dict[k]
+ return new_state_dict
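+# Newer diffusers checkpoints name the VAE attention projections
+# to_q/to_k/to_v/to_out.0; vae_extra_conversion_map renames them back to SD's
+# q/k/v/proj_out, with reshape_weight_for_sd() restoring the 1x1-conv shape for
+# the weights (biases pass through).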
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+
+
+textenc_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
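+# Background: torch.nn.MultiheadAttention stores q/k/v fused in a single
+# `in_proj_weight` of shape (3 * d_model, d_model), ordered q, k, v (hence the
+# index table above); the converter below gathers HF's separate q/k/v
+# projections and torch.cat()s them back in that order.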
+
+
+def convert_text_enc_state_dict_v20(text_enc_dict):
+ new_state_dict = {}
+ capture_qkv_weight = {}
+ capture_qkv_bias = {}
+ for k, v in text_enc_dict.items():
+ if (
+ k.endswith(".self_attn.q_proj.weight")
+ or k.endswith(".self_attn.k_proj.weight")
+ or k.endswith(".self_attn.v_proj.weight")
+ ):
+ k_pre = k[: -len(".q_proj.weight")]
+ k_code = k[-len("q_proj.weight")]
+ if k_pre not in capture_qkv_weight:
+ capture_qkv_weight[k_pre] = [None, None, None]
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
+ continue
+
+ if (
+ k.endswith(".self_attn.q_proj.bias")
+ or k.endswith(".self_attn.k_proj.bias")
+ or k.endswith(".self_attn.v_proj.bias")
+ ):
+ k_pre = k[: -len(".q_proj.bias")]
+ k_code = k[-len("q_proj.bias")]
+ if k_pre not in capture_qkv_bias:
+ capture_qkv_bias[k_pre] = [None, None, None]
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
+ continue
+
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+ new_state_dict[relabelled_key] = v
+
+ for k_pre, tensors in capture_qkv_weight.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+ for k_pre, tensors in capture_qkv_bias.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+ return new_state_dict
+
+
+def convert_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+ parser.add_argument(
+        "--use_safetensors", action="store_true", help="Save the weights in safetensors format (default: .ckpt)."
+ )
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ # Path for safetensors
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
+
+    # Load each model from safetensors if present; otherwise fall back to the PyTorch .bin weights
+ if osp.exists(unet_path):
+ unet_state_dict = load_file(unet_path, device="cpu")
+ else:
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+
+ if osp.exists(vae_path):
+ vae_state_dict = load_file(vae_path, device="cpu")
+ else:
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+
+ if osp.exists(text_enc_path):
+ text_enc_dict = load_file(text_enc_path, device="cpu")
+ else:
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+    # The easiest way to identify a v2.0 model seems to be that its text encoder (OpenCLIP) is deeper
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
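+    # Hedged assumption: SD 2.x ships an OpenCLIP ViT-H text encoder with 24
+    # layers (indices 0-23), while SD 1.x's CLIP ViT-L has 12, so layer 22
+    # should only appear in v2-style checkpoints.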
+
+ if is_v20_model:
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+ else:
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+
+ if args.use_safetensors:
+ save_file(state_dict, args.checkpoint_path)
+ else:
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
diff --git a/diffusers/scripts/convert_dit_to_diffusers.py b/diffusers/scripts/convert_dit_to_diffusers.py
new file mode 100755
index 0000000..dc127f6
--- /dev/null
+++ b/diffusers/scripts/convert_dit_to_diffusers.py
@@ -0,0 +1,162 @@
+import argparse
+import os
+
+import torch
+from torchvision.datasets.utils import download_url
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
+
+
+pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
+
+
+def download_model(model_name):
+ """
+ Downloads a pre-trained DiT model from the web.
+ """
+ local_path = f"pretrained_models/{model_name}"
+ if not os.path.isfile(local_path):
+ os.makedirs("pretrained_models", exist_ok=True)
+ web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
+ download_url(web_path, "pretrained_models")
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
+ return model
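+# Example usage (filename taken from `pretrained_models` above):
+#   state_dict = download_model("DiT-XL-2-256x256.pt")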
+
+
+def main(args):
+ state_dict = download_model(pretrained_models[args.image_size])
+
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ for depth in range(28):
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
+ "t_embedder.mlp.0.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
+ "t_embedder.mlp.0.bias"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
+ "t_embedder.mlp.2.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
+ "t_embedder.mlp.2.bias"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
+ "y_embedder.embedding_table.weight"
+ ]
+
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
+ f"blocks.{depth}.adaLN_modulation.1.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
+ f"blocks.{depth}.adaLN_modulation.1.bias"
+ ]
+
+ q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
+
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
+
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
+ f"blocks.{depth}.attn.proj.weight"
+ ]
+ state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
+
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
+ state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
+
+ state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
+ state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
+ state_dict.pop(f"blocks.{depth}.attn.proj.weight")
+ state_dict.pop(f"blocks.{depth}.attn.proj.bias")
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
+ state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
+
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("y_embedder.embedding_table.weight")
+
+ state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
+ state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
+ state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
+ state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
+
+ state_dict.pop("final_layer.linear.weight")
+ state_dict.pop("final_layer.linear.bias")
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+ # DiT XL/2
+ transformer = Transformer2DModel(
+ sample_size=args.image_size // 8,
+ num_layers=28,
+ attention_head_dim=72,
+ in_channels=4,
+ out_channels=8,
+ patch_size=2,
+ attention_bias=True,
+ num_attention_heads=16,
+ activation_fn="gelu-approximate",
+ num_embeds_ada_norm=1000,
+ norm_type="ada_norm_zero",
+ norm_elementwise_affine=False,
+ )
+ transformer.load_state_dict(state_dict, strict=True)
+
+ scheduler = DDIMScheduler(
+ num_train_timesteps=1000,
+ beta_schedule="linear",
+ prediction_type="epsilon",
+ clip_sample=False,
+ )
+
+ vae = AutoencoderKL.from_pretrained(args.vae_model)
+
+ pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
+
+ if args.save:
+ pipeline.save_pretrained(args.checkpoint_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--image_size",
+ default=256,
+ type=int,
+ required=False,
+ help="Image size of pretrained model, either 256 or 512.",
+ )
+ parser.add_argument(
+ "--vae_model",
+ default="stabilityai/sd-vae-ft-ema",
+ type=str,
+ required=False,
+ help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
+ )
+    parser.add_argument(
+        "--save",
+        default=True,
+        # Note: `type=bool` would treat any non-empty string (even "False") as True,
+        # so parse common truthy strings explicitly instead.
+        type=lambda s: s.lower() in {"1", "true", "yes"},
+        required=False,
+        help="Whether to save the converted pipeline or not.",
+    )
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/diffusers/scripts/convert_flux_to_diffusers.py b/diffusers/scripts/convert_flux_to_diffusers.py
new file mode 100755
index 0000000..05a1da2
--- /dev/null
+++ b/diffusers/scripts/convert_flux_to_diffusers.py
@@ -0,0 +1,303 @@
+import argparse
+from contextlib import nullcontext
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers import AutoencoderKL, FluxTransformer2DModel
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+"""
+# Transformer
+
+python scripts/convert_flux_to_diffusers.py \
+--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
+--filename "flux1-schnell.sft" \
+--output_path "flux-schnell" \
+--transformer
+"""
+
+"""
+# VAE
+
+python scripts/convert_flux_to_diffusers.py \
+--original_state_dict_repo_id "black-forest-labs/FLUX.1-schnell" \
+--filename "ae.sft" \
+--output_path "flux-schnell" \
+--vae
+"""
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
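+# CTX enables meta-device initialization when accelerate is installed, so model
+# weights need not be allocated before load_state_dict; a hedged usage sketch
+# (the code below constructs models eagerly instead) would be:
+#   with CTX():
+#       model = FluxTransformer2DModel(...)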
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="flux.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--vae", action="store_true")
+parser.add_argument("--transformer", action="store_true")
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
+
+
+def load_original_checkpoint(args):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+# In the original SD3 implementation of AdaLayerNormContinuous, the output of the linear projection is split into
+# (shift, scale), while diffusers splits it into (scale, shift). Swapping the two halves of the projection weights
+# lets us reuse the diffusers implementation as-is.
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
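+# Worked toy example (assumed 4-row weight): for rows [s0, s1, c0, c1], where
+# the first half is the shift block and the second the scale block, chunk(2)
+# yields shift=[s0, s1] and scale=[c0, c1], and the returned cat is
+# [c0, c1, s0, s1] -- scale first, as diffusers expects.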
+
+
+def convert_flux_transformer_checkpoint_to_diffusers(
+ original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
+):
+ converted_state_dict = {}
+
+ ## time_text_embed.timestep_embedder <- time_in
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_in.out_layer.bias"
+ )
+
+ ## time_text_embed.text_embedder <- vector_in
+ converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+ "vector_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+ "vector_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+ "vector_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+ "vector_in.out_layer.bias"
+ )
+
+ # guidance
+ has_guidance = any("guidance" in k for k in original_state_dict)
+ if has_guidance:
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
+ "guidance_in.in_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
+ "guidance_in.in_layer.bias"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
+ "guidance_in.out_layer.weight"
+ )
+ converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
+ "guidance_in.out_layer.bias"
+ )
+
+ # context_embedder
+ converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
+ converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")
+
+ # x_embedder
+ converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
+ converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ # norms.
+ ## norm1
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.bias"
+ )
+ ## norm1_context
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.bias"
+ )
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
+ )
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk_norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+ )
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.2.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.2.bias"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.bias"
+ )
+
+    # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.bias"
+ )
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
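+        # Assumption from the split sizes above: single_blocks fuse q, k, v and
+        # the MLP up-projection in one linear1, laid out [q | k | v | mlp]
+        # along dim 0 (three inner_dim chunks plus mlp_hidden_dim).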
+ q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+ # qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.key_norm.scale"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.bias"
+ )
+
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
+ )
+
+ return converted_state_dict
+
+
+def main(args):
+ original_ckpt = load_original_checkpoint(args)
+ has_guidance = any("guidance" in k for k in original_ckpt)
+
+ if args.transformer:
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+ converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
+ original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
+ )
+ transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ print(
+ f"Saving Flux Transformer in Diffusers format. Variant: {'guidance-distilled' if has_guidance else 'timestep-distilled'}"
+ )
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ if args.vae:
+ config = AutoencoderKL.load_config("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
+ vae = AutoencoderKL.from_config(config, scaling_factor=0.3611, shift_factor=0.1159).to(torch.bfloat16)
+
+ converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/diffusers/scripts/convert_gligen_to_diffusers.py b/diffusers/scripts/convert_gligen_to_diffusers.py
new file mode 100755
index 0000000..83c1f92
--- /dev/null
+++ b/diffusers/scripts/convert_gligen_to_diffusers.py
@@ -0,0 +1,581 @@
+import argparse
+import re
+
+import torch
+import yaml
+from transformers import (
+ CLIPProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ DDIMScheduler,
+ StableDiffusionGLIGENPipeline,
+ StableDiffusionGLIGENTextImagePipeline,
+ UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+ assign_to_checkpoint,
+ conv_attn_to_linear,
+ protected,
+ renew_attention_paths,
+ renew_resnet_paths,
+ renew_vae_attention_paths,
+ renew_vae_resnet_paths,
+ shave_segments,
+ textenc_conversion_map,
+ textenc_pattern,
+)
+
+
+def convert_open_clip_checkpoint(checkpoint):
+ checkpoint = checkpoint["text_encoder"]
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+
+ keys = list(checkpoint.keys())
+
+ text_model_dict = {}
+
+ if "cond_stage_model.model.text_projection" in checkpoint:
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
+ else:
+ d_model = 1024
+
+ for key in keys:
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
+ continue
+ if key in textenc_conversion_map:
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
+ # if key.startswith("cond_stage_model.model.transformer."):
+ new_key = key[len("transformer.") :]
+ if new_key.endswith(".in_proj_weight"):
+ new_key = new_key[: -len(".in_proj_weight")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
+ elif new_key.endswith(".in_proj_bias"):
+ new_key = new_key[: -len(".in_proj_bias")]
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
+ else:
+ if key != "transformer.text_model.embeddings.position_ids":
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+
+ text_model_dict[new_key] = checkpoint[key]
+
+ if key == "transformer.text_model.embeddings.token_embedding.weight":
+ text_model_dict["text_model.embeddings.token_embedding.weight"] = checkpoint[key]
+
+ text_model_dict.pop("text_model.embeddings.transformer.text_model.embeddings.token_embedding.weight")
+
+ text_model.load_state_dict(text_model_dict)
+
+ return text_model
+
+
+def convert_gligen_vae_checkpoint(checkpoint, config):
+ checkpoint = checkpoint["autoencoder"]
+ vae_state_dict = {}
+ vae_key = "first_stage_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for key in new_checkpoint.keys():
+ if "encoder.mid_block.attentions.0" in key or "decoder.mid_block.attentions.0" in key:
+ if "query" in key:
+ new_checkpoint[key.replace("query", "to_q")] = new_checkpoint.pop(key)
+ if "value" in key:
+ new_checkpoint[key.replace("value", "to_v")] = new_checkpoint.pop(key)
+ if "key" in key:
+ new_checkpoint[key.replace("key", "to_k")] = new_checkpoint.pop(key)
+ if "proj_attn" in key:
+ new_checkpoint[key.replace("proj_attn", "to_out.0")] = new_checkpoint.pop(key)
+
+ return new_checkpoint
+
+
+def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
+ unet_state_dict = {}
+ checkpoint = checkpoint["model"]
+ keys = list(checkpoint.keys())
+
+ unet_key = "model.diffusion_model."
+
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ print(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ print(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+ for key in keys:
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ for key in keys:
+ if "position_net" in key:
+ new_checkpoint[key] = unet_state_dict[key]
+
+ return new_checkpoint
+
+
+def create_vae_config(original_config, image_size: int):
+ vae_params = original_config["autoencoder"]["params"]["ddconfig"]
+ _ = original_config["autoencoder"]["params"]["embed_dim"]
+
+ block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = {
+ "sample_size": image_size,
+ "in_channels": vae_params["in_channels"],
+ "out_channels": vae_params["out_ch"],
+ "down_block_types": tuple(down_block_types),
+ "up_block_types": tuple(up_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "latent_channels": vae_params["z_channels"],
+ "layers_per_block": vae_params["num_res_blocks"],
+ }
+
+ return config
+
+
+def create_unet_config(original_config, image_size: int, attention_type):
+ unet_params = original_config["model"]["params"]
+ vae_params = original_config["autoencoder"]["params"]["ddconfig"]
+
+ block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
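+    # Example: a 4-entry ch_mult gives vae_scale_factor = 2**3 = 8, so a
+    # 512-pixel image corresponds to sample_size 64 in latent space.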
+
+ head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
+ use_linear_projection = (
+ unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
+ )
+ if use_linear_projection:
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ config = {
+ "sample_size": image_size // vae_scale_factor,
+ "in_channels": unet_params["in_channels"],
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": unet_params["num_res_blocks"],
+ "cross_attention_dim": unet_params["context_dim"],
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "attention_type": attention_type,
+ }
+
+ return config
+
+
+def convert_gligen_to_diffusers(
+ checkpoint_path: str,
+ original_config_file: str,
+ attention_type: str,
+ image_size: int = 512,
+ extract_ema: bool = False,
+ num_in_channels: int = None,
+ device: str = None,
+):
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    if "global_step" not in checkpoint:
+        print("global_step key not found in model")
+
+ original_config = yaml.safe_load(original_config_file)
+
+ if num_in_channels is not None:
+ original_config["model"]["params"]["in_channels"] = num_in_channels
+
+ num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
+ beta_start = original_config["diffusion"]["params"]["linear_start"]
+ beta_end = original_config["diffusion"]["params"]["linear_end"]
+
+ scheduler = DDIMScheduler(
+ beta_end=beta_end,
+ beta_schedule="scaled_linear",
+ beta_start=beta_start,
+ num_train_timesteps=num_train_timesteps,
+ steps_offset=1,
+ clip_sample=False,
+ set_alpha_to_one=False,
+ prediction_type="epsilon",
+ )
+
+ # Convert the UNet2DConditionalModel model
+ unet_config = create_unet_config(original_config, image_size, attention_type)
+ unet = UNet2DConditionModel(**unet_config)
+
+ converted_unet_checkpoint = convert_gligen_unet_checkpoint(
+ checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+ )
+
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ # Convert the VAE model
+ vae_config = create_vae_config(original_config, image_size)
+ converted_vae_checkpoint = convert_gligen_vae_checkpoint(checkpoint, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_checkpoint)
+
+ # Convert the text model
+ text_encoder = convert_open_clip_checkpoint(checkpoint)
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+
+ if attention_type == "gated-text-image":
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+ pipe = StableDiffusionGLIGENTextImagePipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ processor=processor,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+ elif attention_type == "gated":
+ pipe = StableDiffusionGLIGENPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+        )
+    else:
+        raise ValueError(f"Unknown attention type: {attention_type}")
+
+ return pipe
+
+
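+# Example invocation (a minimal sketch; all paths below are placeholders):
+#   python convert_gligen_to_diffusers.py \
+#     --checkpoint_path ./gligen.ckpt \
+#     --original_config_file ./gligen.yaml \
+#     --attention_type gated \
+#     --dump_path ./gligen-diffusers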
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--original_config_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The YAML config file corresponding to the gligen architecture.",
+ )
+ parser.add_argument(
+ "--num_in_channels",
+ default=None,
+ type=int,
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
+ )
+ parser.add_argument(
+ "--extract_ema",
+ action="store_true",
+ help=(
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+ ),
+ )
+ parser.add_argument(
+ "--attention_type",
+ default=None,
+ type=str,
+ required=True,
+ help="Type of attention ex: gated or gated-text-image",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--device", type=str, help="Device to use.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+
+ args = parser.parse_args()
+
+ pipe = convert_gligen_to_diffusers(
+ checkpoint_path=args.checkpoint_path,
+ original_config_file=args.original_config_file,
+ attention_type=args.attention_type,
+ extract_ema=args.extract_ema,
+ num_in_channels=args.num_in_channels,
+ device=args.device,
+ )
+
+ if args.half:
+ pipe.to(dtype=torch.float16)
+
+ pipe.save_pretrained(args.dump_path)
diff --git a/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py
new file mode 100755
index 0000000..1c83836
--- /dev/null
+++ b/diffusers/scripts/convert_hunyuandit_controlnet_to_diffusers.py
@@ -0,0 +1,241 @@
+import argparse
+
+import torch
+
+from diffusers import HunyuanDiT2DControlNetModel
+
+
+def main(args):
+ state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
+
+ if args.load_key != "none":
+ try:
+ state_dict = state_dict[args.load_key]
+ except KeyError:
+            raise KeyError(
+                f"{args.load_key} not found in the checkpoint. "
+                f"Please load from the following keys: {state_dict.keys()}"
+            )
+ device = "cuda"
+
+ model_config = HunyuanDiT2DControlNetModel.load_config(
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
+ )
+    # version <= v1.1: True; version >= v1.2: False
+    model_config["use_style_cond_and_image_meta_size"] = args.use_style_cond_and_image_meta_size
+ print(model_config)
+
+ for key in state_dict:
+ print("local:", key)
+
+ model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)
+
+ for key in model.state_dict():
+ print("diffusers:", key)
+
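+    # Remap each of the 19 ControlNet transformer blocks in place: fused attention projections are
+    # split (e.g. Wqkv.weight of shape (3*D, D) is chunked into three (D, D) matrices for
+    # to_q/to_k/to_v), norms and projections are renamed to the diffusers layout, and the original
+    # keys are popped so that load_state_dict() sees exactly the expected key set.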
+ num_layers = 19
+ for i in range(num_layers):
+ # attn1
+        # Wqkv -> to_q, to_k, to_v
+ q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
+ state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
+ state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
+ state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
+
+ # attn2
+        # kv_proj -> to_k, to_v
+ k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
+ state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
+
+ # q_proj -> to_q
+ state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
+
+        # swap norm2 and norm3 (the diffusers block indexes these two norms in the opposite order)
+ norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
+ norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
+ state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
+ state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
+ state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
+ state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
+
+ # norm1 -> norm1.norm
+ # default_modulation.1 -> norm1.linear
+ state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
+ state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
+ state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
+ state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
+ state_dict.pop(f"blocks.{i}.norm1.weight")
+ state_dict.pop(f"blocks.{i}.norm1.bias")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
+
+ # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
+ state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
+ state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
+
+ # after_proj_list -> controlnet_blocks
+ state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
+ state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
+ state_dict.pop(f"after_proj_list.{i}.weight")
+ state_dict.pop(f"after_proj_list.{i}.bias")
+
+ # before_proj -> input_block
+ state_dict["input_block.weight"] = state_dict["before_proj.weight"]
+ state_dict["input_block.bias"] = state_dict["before_proj.bias"]
+ state_dict.pop("before_proj.weight")
+ state_dict.pop("before_proj.bias")
+
+ # pooler -> time_extra_emb
+ state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
+ state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
+ state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
+ state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
+ state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
+ state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
+ state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
+ state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
+ state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
+ state_dict.pop("pooler.k_proj.weight")
+ state_dict.pop("pooler.k_proj.bias")
+ state_dict.pop("pooler.q_proj.weight")
+ state_dict.pop("pooler.q_proj.bias")
+ state_dict.pop("pooler.v_proj.weight")
+ state_dict.pop("pooler.v_proj.bias")
+ state_dict.pop("pooler.c_proj.weight")
+ state_dict.pop("pooler.c_proj.bias")
+ state_dict.pop("pooler.positional_embedding")
+
+    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
+ state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
+
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+
+    # x_embedder -> pos_embed (`PatchEmbed`)
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ # mlp_t5 -> text_embedder
+ state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
+ state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
+ state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
+ state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
+ state_dict.pop("mlp_t5.0.bias")
+ state_dict.pop("mlp_t5.0.weight")
+ state_dict.pop("mlp_t5.2.bias")
+ state_dict.pop("mlp_t5.2.weight")
+
+    # extra_embedder -> time_extra_emb.extra_embedder
+ state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
+ state_dict.pop("extra_embedder.0.bias")
+ state_dict.pop("extra_embedder.0.weight")
+ state_dict.pop("extra_embedder.2.bias")
+ state_dict.pop("extra_embedder.2.weight")
+
+ # style_embedder
+ if model_config["use_style_cond_and_image_meta_size"]:
+ print(state_dict["style_embedder.weight"])
+ print(state_dict["style_embedder.weight"].shape)
+ state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
+ state_dict.pop("style_embedder.weight")
+
+ model.load_state_dict(state_dict)
+
+ if args.save:
+ model.save_pretrained(args.output_checkpoint_path)
+
+
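+# Example invocation (a minimal sketch; paths are placeholders):
+#   python convert_hunyuandit_controlnet_to_diffusers.py \
+#     --pt_checkpoint_path ./controlnet.pt --load_key none --output_checkpoint_path ./hunyuandit-controlnet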
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+    # NOTE: argparse's `type=bool` treats any non-empty string as True (e.g. "False" -> True)
+    parser.add_argument(
+        "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+    )
+ parser.add_argument(
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
+ )
+ parser.add_argument(
+ "--output_checkpoint_path",
+ default=None,
+ type=str,
+ required=False,
+ help="Path to the output converted diffusers pipeline.",
+ )
+ parser.add_argument(
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
+ )
+ parser.add_argument(
+ "--use_style_cond_and_image_meta_size",
+ type=bool,
+ default=False,
+ help="version <= v1.1: True; version >= v1.2: False",
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/diffusers/scripts/convert_hunyuandit_to_diffusers.py b/diffusers/scripts/convert_hunyuandit_to_diffusers.py
new file mode 100755
index 0000000..da3af83
--- /dev/null
+++ b/diffusers/scripts/convert_hunyuandit_to_diffusers.py
@@ -0,0 +1,267 @@
+import argparse
+
+import torch
+
+from diffusers import HunyuanDiT2DModel
+
+
+def main(args):
+ state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
+
+ if args.load_key != "none":
+ try:
+ state_dict = state_dict[args.load_key]
+ except KeyError:
+ raise KeyError(
+ f"{args.load_key} not found in the checkpoint."
+ f"Please load from the following keys:{state_dict.keys()}"
+ )
+
+ device = "cuda"
+ model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
+    # version <= v1.1: True; version >= v1.2: False
+    model_config["use_style_cond_and_image_meta_size"] = args.use_style_cond_and_image_meta_size
+
+ # input_size -> sample_size, text_dim -> cross_attention_dim
+ for key in state_dict:
+ print("local:", key)
+
+ model = HunyuanDiT2DModel.from_config(model_config).to(device)
+
+ for key in model.state_dict():
+ print("diffusers:", key)
+
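+    # Same per-block remapping as in the ControlNet conversion script above, applied to all 40 DiT blocks.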
+ num_layers = 40
+ for i in range(num_layers):
+ # attn1
+        # Wqkv -> to_q, to_k, to_v
+ q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
+ state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
+ state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
+ state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
+
+ # attn2
+        # kv_proj -> to_k, to_v
+ k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
+ state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
+
+ # q_proj -> to_q
+ state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
+
+        # swap norm2 and norm3 (the diffusers block indexes these two norms in the opposite order)
+ norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
+ norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
+ state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
+ state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
+ state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
+ state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
+
+ # norm1 -> norm1.norm
+ # default_modulation.1 -> norm1.linear
+ state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
+ state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
+ state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
+ state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
+ state_dict.pop(f"blocks.{i}.norm1.weight")
+ state_dict.pop(f"blocks.{i}.norm1.bias")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
+
+ # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
+ state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
+ state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
+
+ # pooler -> time_extra_emb
+ state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
+ state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
+ state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
+ state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
+ state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
+ state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
+ state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
+ state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
+ state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
+ state_dict.pop("pooler.k_proj.weight")
+ state_dict.pop("pooler.k_proj.bias")
+ state_dict.pop("pooler.q_proj.weight")
+ state_dict.pop("pooler.q_proj.bias")
+ state_dict.pop("pooler.v_proj.weight")
+ state_dict.pop("pooler.v_proj.bias")
+ state_dict.pop("pooler.c_proj.weight")
+ state_dict.pop("pooler.c_proj.bias")
+ state_dict.pop("pooler.positional_embedding")
+
+    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
+ state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
+
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+
+    # x_embedder -> pos_embed (`PatchEmbed`)
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ # mlp_t5 -> text_embedder
+ state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
+ state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
+ state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
+ state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
+ state_dict.pop("mlp_t5.0.bias")
+ state_dict.pop("mlp_t5.0.weight")
+ state_dict.pop("mlp_t5.2.bias")
+ state_dict.pop("mlp_t5.2.weight")
+
+    # extra_embedder -> time_extra_emb.extra_embedder
+ state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
+ state_dict.pop("extra_embedder.0.bias")
+ state_dict.pop("extra_embedder.0.weight")
+ state_dict.pop("extra_embedder.2.bias")
+ state_dict.pop("extra_embedder.2.weight")
+
+    # final_layer.adaLN_modulation.1 -> norm_out.linear
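+    # (the original final layer emits (shift, scale) while the diffusers norm_out linear expects
+    # (scale, shift), hence the swap below)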
+ def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+ state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"])
+ state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"])
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+ # final_linear -> proj_out
+ state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"]
+ state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"]
+ state_dict.pop("final_layer.linear.weight")
+ state_dict.pop("final_layer.linear.bias")
+
+ # style_embedder
+ if model_config["use_style_cond_and_image_meta_size"]:
+ print(state_dict["style_embedder.weight"])
+ print(state_dict["style_embedder.weight"].shape)
+ state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
+ state_dict.pop("style_embedder.weight")
+
+ model.load_state_dict(state_dict)
+
+ from diffusers import HunyuanDiTPipeline
+
+ if args.use_style_cond_and_image_meta_size:
+ pipe = HunyuanDiTPipeline.from_pretrained(
+ "Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32
+ )
+ else:
+ pipe = HunyuanDiTPipeline.from_pretrained(
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32
+ )
+ pipe.to("cuda")
+ pipe.to(dtype=torch.float32)
+
+ if args.save:
+ pipe.save_pretrained(args.output_checkpoint_path)
+
+    # NOTE: HunyuanDiT supports both Chinese and English prompts
+ prompt = "一个宇航员在骑马"
+ # prompt = "An astronaut riding a horse"
+ generator = torch.Generator(device="cuda").manual_seed(0)
+ image = pipe(
+ height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0
+ ).images[0]
+
+ image.save("img.png")
+
+
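+# Example invocation (a minimal sketch; paths are placeholders):
+#   python convert_hunyuandit_to_diffusers.py \
+#     --pt_checkpoint_path ./hunyuandit.pt --load_key none --output_checkpoint_path ./hunyuandit-diffusers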
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+    # NOTE: argparse's `type=bool` treats any non-empty string as True (e.g. "False" -> True)
+    parser.add_argument(
+        "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+    )
+ parser.add_argument(
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
+ )
+ parser.add_argument(
+ "--output_checkpoint_path",
+ default=None,
+ type=str,
+ required=False,
+ help="Path to the output converted diffusers pipeline.",
+ )
+ parser.add_argument(
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
+ )
+ parser.add_argument(
+ "--use_style_cond_and_image_meta_size",
+ type=bool,
+ default=False,
+ help="version <= v1.1: True; version >= v1.2: False",
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/diffusers/scripts/convert_i2vgen_to_diffusers.py b/diffusers/scripts/convert_i2vgen_to_diffusers.py
new file mode 100755
index 0000000..b9e3ff2
--- /dev/null
+++ b/diffusers/scripts/convert_i2vgen_to_diffusers.py
@@ -0,0 +1,510 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Conversion script for the LDM checkpoints."""
+
+import argparse
+
+import torch
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers import DDIMScheduler, I2VGenXLPipeline, I2VGenXLUNet, StableDiffusionPipeline
+
+
+CLIP_ID = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ """
+    This does the final conversion step: it takes locally converted weights, applies a global renaming to them,
+    splits attention layers where requested, and applies any additional replacements, assigning the result to the
+    new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ weight = old_checkpoint[path["old"]]
+ names = ["proj_attn.weight"]
+ names_2 = ["proj_out.weight", "proj_in.weight"]
+ if any(k in new_path for k in names):
+ checkpoint[new_path] = weight[:, :, 0]
+ elif any(k in new_path for k in names_2) and len(weight.shape) > 2 and ".attentions." not in new_path:
+ checkpoint[new_path] = weight[:, :, 0]
+ else:
+ checkpoint[new_path] = weight
+
+
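+# NOTE: the mapping returned below is the identity ({"old": k, "new": k}); the actual renaming is
+# applied later by assign_to_checkpoint through the `additional_replacements` meta paths.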
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+    Removes segments. Positive values shave the first segments, negative values shave the last segments. For example,
+    shave_segments("a.b.c.d", 1) returns "b.c.d" and shave_segments("a.b.c.d", -1) returns "a.b.c".
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_temp_conv_paths(old_list, n_shave_prefix_segments=0):
+ """
+    Updates paths inside temporal convolutions to the new naming scheme (the mapping here is the identity)
+ """
+ mapping = []
+ for old_item in old_list:
+ mapping.append({"old": old_item, "new": old_item})
+
+ return mapping
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
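+        # NOTE: "temopral_conv" (sic) matches the misspelled key names in the original checkpoint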
+ if "temopral_conv" not in old_item:
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ unet_key = "model.diffusion_model."
+
+    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+ print(
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+ )
+ for key in keys:
+ if key.startswith("model.diffusion_model"):
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+ else:
+ if sum(k.startswith("model_ema") for k in keys) > 100:
+ print(
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+ )
+
+ for key in keys:
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ additional_embedding_substrings = [
+ "local_image_concat",
+ "context_embedding",
+ "local_image_embedding",
+ "fps_embedding",
+ ]
+ for k in unet_state_dict:
+ if any(substring in k for substring in additional_embedding_substrings):
+ diffusers_key = k.replace("local_image_concat", "image_latents_proj_in").replace(
+ "local_image_embedding", "image_latents_context_embedding"
+ )
+ new_checkpoint[diffusers_key] = unet_state_dict[k]
+
+ # temporal encoder.
+ new_checkpoint["image_latents_temporal_encoder.norm1.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.norm.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.norm1.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.norm.bias"
+ ]
+
+ # attention
+ qkv = unet_state_dict["local_temporal_encoder.layers.0.0.fn.to_qkv.weight"]
+ q, k, v = torch.chunk(qkv, 3, dim=0)
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_q.weight"] = q
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_k.weight"] = k
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_v.weight"] = v
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.fn.to_out.0.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.attn1.to_out.0.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.0.fn.to_out.0.bias"
+ ]
+
+ # feedforward
+ new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.0.0.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.ff.net.0.proj.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.0.0.bias"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.ff.net.2.weight"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.2.weight"
+ ]
+ new_checkpoint["image_latents_temporal_encoder.ff.net.2.bias"] = unet_state_dict[
+ "local_temporal_encoder.layers.0.1.net.2.bias"
+ ]
+
+ if "class_embed_type" in config:
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ first_temp_attention = [v for v in unet_state_dict if v.startswith("input_blocks.0.1")]
+ paths = renew_attention_paths(first_temp_attention)
+ meta_path = {"old": "input_blocks.0.1", "new": "transformer_in"}
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+ temp_attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.2" in key]
+
+ if f"input_blocks.{i}.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ temporal_convs = [key for key in resnets if "temopral_conv" in key]
+ paths = renew_temp_conv_paths(temporal_convs)
+ meta_path = {
+ "old": f"input_blocks.{i}.0.temopral_conv",
+ "new": f"down_blocks.{block_id}.temp_convs.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(temp_attentions):
+ paths = renew_attention_paths(temp_attentions)
+ meta_path = {
+ "old": f"input_blocks.{i}.2",
+ "new": f"down_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ temporal_convs_0 = [key for key in resnet_0 if "temopral_conv" in key]
+ attentions = middle_blocks[1]
+ temp_attentions = middle_blocks[2]
+ resnet_1 = middle_blocks[3]
+ temporal_convs_1 = [key for key in resnet_1 if "temopral_conv" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ meta_path = {"old": "middle_block.0", "new": "mid_block.resnets.0"}
+ assign_to_checkpoint(
+ resnet_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ temp_conv_0_paths = renew_temp_conv_paths(temporal_convs_0)
+ meta_path = {"old": "middle_block.0.temopral_conv", "new": "mid_block.temp_convs.0"}
+ assign_to_checkpoint(
+ temp_conv_0_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ meta_path = {"old": "middle_block.3", "new": "mid_block.resnets.1"}
+ assign_to_checkpoint(
+ resnet_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ temp_conv_1_paths = renew_temp_conv_paths(temporal_convs_1)
+ meta_path = {"old": "middle_block.3.temopral_conv", "new": "mid_block.temp_convs.1"}
+ assign_to_checkpoint(
+ temp_conv_1_paths, new_checkpoint, unet_state_dict, config=config, additional_replacements=[meta_path]
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ temp_attentions_paths = renew_attention_paths(temp_attentions)
+ meta_path = {"old": "middle_block.2", "new": "mid_block.temp_attentions.0"}
+ assign_to_checkpoint(
+ temp_attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+ temp_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
+
+            paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ temporal_convs = [key for key in resnets if "temopral_conv" in key]
+ paths = renew_temp_conv_paths(temporal_convs)
+ meta_path = {
+ "old": f"output_blocks.{i}.0.temopral_conv",
+ "new": f"up_blocks.{block_id}.temp_convs.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(temp_attentions):
+ paths = renew_attention_paths(temp_attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.2",
+ "new": f"up_blocks.{block_id}.temp_attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ temopral_conv_paths = [l for l in output_block_layers if "temopral_conv" in l]
+ for path in temopral_conv_paths:
+ pruned_path = path.split("temopral_conv.")[-1]
+ old_path = ".".join(["output_blocks", str(i), str(block_id), "temopral_conv", pruned_path])
+ new_path = ".".join(["up_blocks", str(block_id), "temp_convs", str(layer_in_block_id), pruned_path])
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ return new_checkpoint
+
+
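+# Example invocation (a minimal sketch; the checkpoint path is a placeholder):
+#   python convert_i2vgen_to_diffusers.py --unet_checkpoint_path ./i2vgen_xl_unet.pt --dump_path ./i2vgen-xl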
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--unet_checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--push_to_hub", action="store_true")
+ args = parser.parse_args()
+
+ # UNet
+ unet_checkpoint = torch.load(args.unet_checkpoint_path, map_location="cpu")
+ unet_checkpoint = unet_checkpoint["state_dict"]
+ unet = I2VGenXLUNet(sample_size=32)
+
+ converted_ckpt = convert_ldm_unet_checkpoint(unet_checkpoint, unet.config)
+
+ diff_0 = set(unet.state_dict().keys()) - set(converted_ckpt.keys())
+ diff_1 = set(converted_ckpt.keys()) - set(unet.state_dict().keys())
+
+ assert len(diff_0) == len(diff_1) == 0, "Converted weights don't match"
+
+ unet.load_state_dict(converted_ckpt, strict=True)
+
+ # vae
+ temp_pipe = StableDiffusionPipeline.from_single_file(
+ "https://huggingface.co/ali-vilab/i2vgen-xl/blob/main/models/v2-1_512-ema-pruned.ckpt"
+ )
+ vae = temp_pipe.vae
+ del temp_pipe
+
+ # text encoder and tokenizer
+ text_encoder = CLIPTextModel.from_pretrained(CLIP_ID)
+ tokenizer = CLIPTokenizer.from_pretrained(CLIP_ID)
+
+ # image encoder and feature extractor
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(CLIP_ID)
+ feature_extractor = CLIPImageProcessor.from_pretrained(CLIP_ID)
+
+ # scheduler
+ # https://github.com/ali-vilab/i2vgen-xl/blob/main/configs/i2vgen_xl_train.yaml
+ scheduler = DDIMScheduler(
+ beta_schedule="squaredcos_cap_v2",
+ rescale_betas_zero_snr=True,
+ set_alpha_to_one=True,
+ clip_sample=False,
+ steps_offset=1,
+ timestep_spacing="leading",
+ prediction_type="v_prediction",
+ )
+
+ # final
+ pipeline = I2VGenXLPipeline(
+ unet=unet,
+ vae=vae,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ )
+
+ pipeline.save_pretrained(args.dump_path, push_to_hub=args.push_to_hub)
diff --git a/diffusers/scripts/convert_if.py b/diffusers/scripts/convert_if.py
new file mode 100755
index 0000000..85c739c
--- /dev/null
+++ b/diffusers/scripts/convert_if.py
@@ -0,0 +1,1250 @@
+import argparse
+import inspect
+import os
+
+import numpy as np
+import torch
+import yaml
+from torch.nn import functional as F
+from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
+
+from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet2DConditionModel
+from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dump_path", required=False, default=None, type=str)
+
+ parser.add_argument("--dump_path_stage_2", required=False, default=None, type=str)
+
+ parser.add_argument("--dump_path_stage_3", required=False, default=None, type=str)
+
+ parser.add_argument("--unet_config", required=False, default=None, type=str, help="Path to unet config file")
+
+ parser.add_argument(
+ "--unet_checkpoint_path", required=False, default=None, type=str, help="Path to unet checkpoint file"
+ )
+
+ parser.add_argument(
+ "--unet_checkpoint_path_stage_2",
+ required=False,
+ default=None,
+ type=str,
+ help="Path to stage 2 unet checkpoint file",
+ )
+
+ parser.add_argument(
+ "--unet_checkpoint_path_stage_3",
+ required=False,
+ default=None,
+ type=str,
+ help="Path to stage 3 unet checkpoint file",
+ )
+
+ parser.add_argument("--p_head_path", type=str, required=True)
+
+ parser.add_argument("--w_head_path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ return args
+
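+# Example invocation (a minimal sketch; paths are placeholders, and only the stages whose
+# checkpoints are provided get converted):
+#   python convert_if.py --p_head_path ./p_head.npz --w_head_path ./w_head.npz \
+#     --unet_config ./stage_1_unet.yaml --unet_checkpoint_path ./stage_1.pt --dump_path ./IF-I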
+
+def main(args):
+ tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
+ text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
+
+ feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
+ safety_checker = convert_safety_checker(p_head_path=args.p_head_path, w_head_path=args.w_head_path)
+
+ if args.unet_config is not None and args.unet_checkpoint_path is not None and args.dump_path is not None:
+ convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args)
+
+ if args.unet_checkpoint_path_stage_2 is not None and args.dump_path_stage_2 is not None:
+ convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=2)
+
+ if args.unet_checkpoint_path_stage_3 is not None and args.dump_path_stage_3 is not None:
+ convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage=3)
+
+
+def convert_stage_1_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args):
+ unet = get_stage_1_unet(args.unet_config, args.unet_checkpoint_path)
+
+ scheduler = DDPMScheduler(
+ variance_type="learned_range",
+ beta_schedule="squaredcos_cap_v2",
+ prediction_type="epsilon",
+ thresholding=True,
+ dynamic_thresholding_ratio=0.95,
+ sample_max_value=1.5,
+ )
+
+ pipe = IFPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=True,
+ )
+
+ pipe.save_pretrained(args.dump_path)
+
+
+def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safety_checker, args, stage):
+ if stage == 2:
+ unet_checkpoint_path = args.unet_checkpoint_path_stage_2
+ sample_size = None
+ dump_path = args.dump_path_stage_2
+ elif stage == 3:
+ unet_checkpoint_path = args.unet_checkpoint_path_stage_3
+ sample_size = 1024
+ dump_path = args.dump_path_stage_3
+    else:
+        raise ValueError(f"unknown stage: {stage}")
+
+ unet = get_super_res_unet(unet_checkpoint_path, verify_param_count=False, sample_size=sample_size)
+
+ image_noising_scheduler = DDPMScheduler(
+ beta_schedule="squaredcos_cap_v2",
+ )
+
+ scheduler = DDPMScheduler(
+ variance_type="learned_range",
+ beta_schedule="squaredcos_cap_v2",
+ prediction_type="epsilon",
+ thresholding=True,
+ dynamic_thresholding_ratio=0.95,
+ sample_max_value=1.0,
+ )
+
+ pipe = IFSuperResolutionPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ scheduler=scheduler,
+ image_noising_scheduler=image_noising_scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ requires_safety_checker=True,
+ )
+
+ pipe.save_pretrained(dump_path)
+
+
+def get_stage_1_unet(unet_config, unet_checkpoint_path):
+    with open(unet_config, "r") as f:
+        original_unet_config = yaml.safe_load(f)
+ original_unet_config = original_unet_config["params"]
+
+ unet_diffusers_config = create_unet_diffusers_config(original_unet_config)
+
+ unet = UNet2DConditionModel(**unet_diffusers_config)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ unet_checkpoint = torch.load(unet_checkpoint_path, map_location=device)
+
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+ unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
+ )
+
+ unet.load_state_dict(converted_unet_checkpoint)
+
+ return unet
+
+
+def convert_safety_checker(p_head_path, w_head_path):
+ state_dict = {}
+
+ # p head
+
+ p_head = np.load(p_head_path)
+
+ p_head_weights = p_head["weights"]
+ p_head_weights = torch.from_numpy(p_head_weights)
+ p_head_weights = p_head_weights.unsqueeze(0)
+
+ p_head_biases = p_head["biases"]
+ p_head_biases = torch.from_numpy(p_head_biases)
+ p_head_biases = p_head_biases.unsqueeze(0)
+
+ state_dict["p_head.weight"] = p_head_weights
+ state_dict["p_head.bias"] = p_head_biases
+
+ # w head
+
+ w_head = np.load(w_head_path)
+
+ w_head_weights = w_head["weights"]
+ w_head_weights = torch.from_numpy(w_head_weights)
+ w_head_weights = w_head_weights.unsqueeze(0)
+
+ w_head_biases = w_head["biases"]
+ w_head_biases = torch.from_numpy(w_head_biases)
+ w_head_biases = w_head_biases.unsqueeze(0)
+
+ state_dict["w_head.weight"] = w_head_weights
+ state_dict["w_head.bias"] = w_head_biases
+
+ # vision model
+
+ vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
+ vision_model_state_dict = vision_model.state_dict()
+
+ for key, value in vision_model_state_dict.items():
+ key = f"vision_model.{key}"
+ state_dict[key] = value
+
+ # full model
+
+ config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
+ safety_checker = IFSafetyChecker(config)
+
+ safety_checker.load_state_dict(state_dict)
+
+ return safety_checker
+
+
+def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
+ attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
+ attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
+
+ channel_mult = parse_list(original_unet_config["channel_mult"])
+ block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
+
+ down_block_types = []
+ resolution = 1
+
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnDownBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetDownsampleBlock2D"
+ else:
+ block_type = "DownBlock2D"
+
+ down_block_types.append(block_type)
+
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnUpBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetUpsampleBlock2D"
+ else:
+ block_type = "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ head_dim = original_unet_config["num_head_channels"]
+
+ use_linear_projection = (
+ original_unet_config["use_linear_in_transformer"]
+ if "use_linear_in_transformer" in original_unet_config
+ else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ projection_class_embeddings_input_dim = None
+
+ if class_embed_type is None:
+ if "num_classes" in original_unet_config:
+ if original_unet_config["num_classes"] == "sequential":
+ class_embed_type = "projection"
+ assert "adm_in_channels" in original_unet_config
+ projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
+ else:
+ raise NotImplementedError(
+ f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
+ )
+
+ config = {
+ "sample_size": original_unet_config["image_size"],
+ "in_channels": original_unet_config["in_channels"],
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": original_unet_config["num_res_blocks"],
+ "cross_attention_dim": original_unet_config["encoder_channels"],
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "out_channels": original_unet_config["out_channels"],
+ "up_block_types": tuple(up_block_types),
+ "upcast_attention": False, # TODO: guessing
+ "cross_attention_norm": "group_norm",
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "addition_embed_type": "text",
+ "act_fn": "gelu",
+ }
+
+ if original_unet_config["use_scale_shift_norm"]:
+ config["resnet_time_scale_shift"] = "scale_shift"
+
+ if "encoder_dim" in original_unet_config:
+ config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
+
+ return config
+
+
+def convert_ldm_unet_checkpoint(unet_state_dict, config, path=None):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] in [None, "identity"]:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+
+ # TODO need better check than i in [4, 8, 12, 16]
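+        # (illustrative: with layers_per_block == 3 each down block spans 4 input_blocks
+        # entries, 3 resnets followed by a downsampler, so downsamplers sit at i = 4, 8, 12, 16)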
+ block_type = config["down_block_types"][block_id]
+ if (block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D") and i in [
+ 4,
+ 8,
+ 12,
+ 16,
+ ]:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
+ else:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ old_path = f"input_blocks.{i}.1"
+ new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": old_path, "new": new_path}
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ unet_state_dict,
+ additional_replacements=[meta_path],
+ config=config,
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ old_path = "middle_block.1"
+ new_path = "mid_block.attentions.0"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ # len(output_block_list) == 1 -> resnet
+ # len(output_block_list) == 2 -> resnet, attention
+ # len(output_block_list) == 3 -> resnet, attention, upscale resnet
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ old_path = f"output_blocks.{i}.1"
+ new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": old_path,
+ "new": new_path,
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(output_block_list) == 3:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.2" in key]
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"output_blocks.{i}.2", "new": f"up_blocks.{block_id}.upsamplers.0"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ if "encoder_proj.weight" in unet_state_dict:
+ new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict.pop("encoder_proj.weight")
+ new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict.pop("encoder_proj.bias")
+
+ if "encoder_pooling.0.weight" in unet_state_dict:
+ new_checkpoint["add_embedding.norm1.weight"] = unet_state_dict.pop("encoder_pooling.0.weight")
+ new_checkpoint["add_embedding.norm1.bias"] = unet_state_dict.pop("encoder_pooling.0.bias")
+
+ new_checkpoint["add_embedding.pool.positional_embedding"] = unet_state_dict.pop(
+ "encoder_pooling.1.positional_embedding"
+ )
+ new_checkpoint["add_embedding.pool.k_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.k_proj.weight")
+ new_checkpoint["add_embedding.pool.k_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.k_proj.bias")
+ new_checkpoint["add_embedding.pool.q_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.q_proj.weight")
+ new_checkpoint["add_embedding.pool.q_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.q_proj.bias")
+ new_checkpoint["add_embedding.pool.v_proj.weight"] = unet_state_dict.pop("encoder_pooling.1.v_proj.weight")
+ new_checkpoint["add_embedding.pool.v_proj.bias"] = unet_state_dict.pop("encoder_pooling.1.v_proj.bias")
+
+ new_checkpoint["add_embedding.proj.weight"] = unet_state_dict.pop("encoder_pooling.2.weight")
+ new_checkpoint["add_embedding.proj.bias"] = unet_state_dict.pop("encoder_pooling.2.bias")
+
+ new_checkpoint["add_embedding.norm2.weight"] = unet_state_dict.pop("encoder_pooling.3.weight")
+ new_checkpoint["add_embedding.norm2.bias"] = unet_state_dict.pop("encoder_pooling.3.bias")
+
+ return new_checkpoint
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+    Removes path segments. A positive `n_shave_prefix_segments` drops that many leading
+    segments; a negative value drops that many trailing segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
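+
+
+# For example, "input_blocks.1.0.in_layers.0.weight" is mapped to
+# "input_blocks.1.0.norm1.weight"; the block prefix itself is renamed later, in
+# `assign_to_checkpoint`.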
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ if "qkv" in new_item:
+ continue
+
+ if "encoder_kv" in new_item:
+ continue
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
+
+ new_item = new_item.replace("norm_encoder.weight", "norm_cross.weight")
+ new_item = new_item.replace("norm_encoder.bias", "norm_cross.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_attention_to_checkpoint(new_checkpoint, unet_state_dict, old_path, new_path, config):
+ qkv_weight = unet_state_dict.pop(f"{old_path}.qkv.weight")
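+    # the qkv projection is a 1x1 conv with weight shape (out, in, 1); drop the
+    # trailing dim to get the equivalent linear weight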
+ qkv_weight = qkv_weight[:, :, 0]
+
+ qkv_bias = unet_state_dict.pop(f"{old_path}.qkv.bias")
+
+ is_cross_attn_only = "only_cross_attention" in config and config["only_cross_attention"]
+
+ split = 1 if is_cross_attn_only else 3
+
+ weights, bias = split_attentions(
+ weight=qkv_weight,
+ bias=qkv_bias,
+ split=split,
+ chunk_size=config["attention_head_dim"],
+ )
+
+ if is_cross_attn_only:
+ query_weight, q_bias = weights, bias
+ new_checkpoint[f"{new_path}.to_q.weight"] = query_weight[0]
+ new_checkpoint[f"{new_path}.to_q.bias"] = q_bias[0]
+ else:
+ [query_weight, key_weight, value_weight], [q_bias, k_bias, v_bias] = weights, bias
+ new_checkpoint[f"{new_path}.to_q.weight"] = query_weight
+ new_checkpoint[f"{new_path}.to_q.bias"] = q_bias
+ new_checkpoint[f"{new_path}.to_k.weight"] = key_weight
+ new_checkpoint[f"{new_path}.to_k.bias"] = k_bias
+ new_checkpoint[f"{new_path}.to_v.weight"] = value_weight
+ new_checkpoint[f"{new_path}.to_v.bias"] = v_bias
+
+ encoder_kv_weight = unet_state_dict.pop(f"{old_path}.encoder_kv.weight")
+ encoder_kv_weight = encoder_kv_weight[:, :, 0]
+
+ encoder_kv_bias = unet_state_dict.pop(f"{old_path}.encoder_kv.bias")
+
+ [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions(
+ weight=encoder_kv_weight,
+ bias=encoder_kv_bias,
+ split=2,
+ chunk_size=config["attention_head_dim"],
+ )
+
+ new_checkpoint[f"{new_path}.add_k_proj.weight"] = encoder_k_weight
+ new_checkpoint[f"{new_path}.add_k_proj.bias"] = encoder_k_bias
+ new_checkpoint[f"{new_path}.add_v_proj.weight"] = encoder_v_weight
+ new_checkpoint[f"{new_path}.add_v_proj.bias"] = encoder_v_bias
+
+
+def assign_to_checkpoint(paths, checkpoint, old_checkpoint, additional_replacements=None, config=None):
+ """
+    Performs the final conversion step: takes the locally converted weights, applies the global renaming (plus any
+    additional replacements), reshapes 1x1 conv projection weights to linear weights where needed, and assigns the
+    weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ for path in paths:
+ new_path = path["new"]
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
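+        # (a 1x1 conv weight has shape (out_channels, in_channels, 1); dropping the
+        # trailing dim gives the (out, in) matrix nn.Linear expects)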
+ if "proj_attn.weight" in new_path or "to_out.0.weight" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?)
+def split_attentions(*, weight, bias, split, chunk_size):
+ weights = [None] * split
+ biases = [None] * split
+
+ weights_biases_idx = 0
+
+ for starting_row_index in range(0, weight.shape[0], chunk_size):
+ row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size)
+
+ weight_rows = weight[row_indices, :]
+ bias_rows = bias[row_indices]
+
+ if weights[weights_biases_idx] is None:
+ weights[weights_biases_idx] = weight_rows
+ biases[weights_biases_idx] = bias_rows
+ else:
+ assert weights[weights_biases_idx] is not None
+ weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows])
+ biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows])
+
+ weights_biases_idx = (weights_biases_idx + 1) % split
+
+ return weights, biases
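+
+
+# For example (illustrative): a fused qkv weight whose rows interleave per-head
+# q/k/v chunks of `chunk_size` rows each,
+#   [q0, k0, v0, q1, k1, v1],
+# is dealt round-robin into split=3 buckets:
+#   weights = [cat(q0, q1), cat(k0, k1), cat(v0, v1)]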
+
+
+def parse_list(value):
+ if isinstance(value, str):
+ value = value.split(",")
+ value = [int(v) for v in value]
+ elif isinstance(value, list):
+ pass
+ else:
+ raise ValueError(f"Can't parse list for type: {type(value)}")
+
+ return value
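+
+
+# For example: parse_list("1,2,4") -> [1, 2, 4]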
+
+
+# below is copy and pasted from original convert_if_stage_2.py script
+
+
+def get_super_res_unet(unet_checkpoint_path, verify_params=True, sample_size=None):
+ orig_path = unet_checkpoint_path
+
+    with open(os.path.join(orig_path, "config.yml")) as f:
+        original_unet_config = yaml.safe_load(f)
+ original_unet_config = original_unet_config["params"]
+
+ unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
+ unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
+ original_unet_config["channel_mult"].split(",")[-1]
+ )
+ if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
+ unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
+ unet_diffusers_config["class_embed_type"] = "timestep"
+ unet_diffusers_config["addition_embed_type"] = "text"
+
+ unet_diffusers_config["time_embedding_act_fn"] = "gelu"
+ unet_diffusers_config["resnet_skip_time_act"] = True
+ unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
+ unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
+ unet_diffusers_config["only_cross_attention"] = (
+ bool(original_unet_config["disable_self_attentions"])
+ if (
+ "disable_self_attentions" in original_unet_config
+ and isinstance(original_unet_config["disable_self_attentions"], int)
+ )
+ else True
+ )
+
+ if sample_size is None:
+ unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
+ else:
+ # The second upscaler unet's sample size is incorrectly specified
+ # in the config and is instead hardcoded in source
+ unet_diffusers_config["sample_size"] = sample_size
+
+ unet_checkpoint = torch.load(os.path.join(unet_checkpoint_path, "pytorch_model.bin"), map_location="cpu")
+
+    if verify_params:
+ # check that architecture matches - is a bit slow
+ verify_param_count(orig_path, unet_diffusers_config)
+
+ converted_unet_checkpoint = superres_convert_ldm_unet_checkpoint(
+ unet_checkpoint, unet_diffusers_config, path=unet_checkpoint_path
+ )
+ converted_keys = converted_unet_checkpoint.keys()
+
+ model = UNet2DConditionModel(**unet_diffusers_config)
+ expected_weights = model.state_dict().keys()
+
+ diff_c_e = set(converted_keys) - set(expected_weights)
+ diff_e_c = set(expected_weights) - set(converted_keys)
+
+ assert len(diff_e_c) == 0, f"Expected, but not converted: {diff_e_c}"
+ assert len(diff_c_e) == 0, f"Converted, but not expected: {diff_c_e}"
+
+ model.load_state_dict(converted_unet_checkpoint)
+
+ return model
+
+
+def superres_create_unet_diffusers_config(original_unet_config):
+ attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
+ attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
+
+ channel_mult = parse_list(original_unet_config["channel_mult"])
+ block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
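+    # e.g. model_channels = 320 with channel_mult "1,2,4" gives
+    # block_out_channels = [320, 640, 1280]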
+
+ down_block_types = []
+ resolution = 1
+
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnDownBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetDownsampleBlock2D"
+ else:
+ block_type = "DownBlock2D"
+
+ down_block_types.append(block_type)
+
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ if resolution in attention_resolutions:
+ block_type = "SimpleCrossAttnUpBlock2D"
+ elif original_unet_config["resblock_updown"]:
+ block_type = "ResnetUpsampleBlock2D"
+ else:
+ block_type = "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ head_dim = original_unet_config["num_head_channels"]
+ use_linear_projection = (
+ original_unet_config["use_linear_in_transformer"]
+ if "use_linear_in_transformer" in original_unet_config
+ else False
+ )
+ if use_linear_projection:
+ # stable diffusion 2-base-512 and 2-768
+ if head_dim is None:
+ head_dim = [5, 10, 20, 20]
+
+ class_embed_type = None
+ projection_class_embeddings_input_dim = None
+
+ if "num_classes" in original_unet_config:
+ if original_unet_config["num_classes"] == "sequential":
+ class_embed_type = "projection"
+ assert "adm_in_channels" in original_unet_config
+ projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
+ else:
+ raise NotImplementedError(
+ f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
+ )
+
+ config = {
+ "in_channels": original_unet_config["in_channels"],
+ "down_block_types": tuple(down_block_types),
+ "block_out_channels": tuple(block_out_channels),
+ "layers_per_block": tuple(original_unet_config["num_res_blocks"]),
+ "cross_attention_dim": original_unet_config["encoder_channels"],
+ "attention_head_dim": head_dim,
+ "use_linear_projection": use_linear_projection,
+ "class_embed_type": class_embed_type,
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
+ "out_channels": original_unet_config["out_channels"],
+ "up_block_types": tuple(up_block_types),
+ "upcast_attention": False, # TODO: guessing
+ "cross_attention_norm": "group_norm",
+ "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
+ "act_fn": "gelu",
+ }
+
+ if original_unet_config["use_scale_shift_norm"]:
+ config["resnet_time_scale_shift"] = "scale_shift"
+
+ return config
+
+
+def superres_convert_ldm_unet_checkpoint(unet_state_dict, config, path=None, extract_ema=False):
+ """
+    Takes the original super-resolution UNet state dict and a diffusers config, and returns a converted checkpoint.
+ """
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ if config["class_embed_type"] is None:
+ # No parameters to port
+ ...
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["aug_proj.0.weight"]
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["aug_proj.0.bias"]
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["aug_proj.2.weight"]
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["aug_proj.2.bias"]
+ else:
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
+
+ if "encoder_proj.weight" in unet_state_dict:
+ new_checkpoint["encoder_hid_proj.weight"] = unet_state_dict["encoder_proj.weight"]
+ new_checkpoint["encoder_hid_proj.bias"] = unet_state_dict["encoder_proj.bias"]
+
+ if "encoder_pooling.0.weight" in unet_state_dict:
+ mapping = {
+ "encoder_pooling.0": "add_embedding.norm1",
+ "encoder_pooling.1": "add_embedding.pool",
+ "encoder_pooling.2": "add_embedding.proj",
+ "encoder_pooling.3": "add_embedding.norm2",
+ }
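+        # e.g. "encoder_pooling.1.k_proj.weight" -> "add_embedding.pool.k_proj.weight"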
+ for key in unet_state_dict.keys():
+ if key.startswith("encoder_pooling"):
+ prefix = key[: len("encoder_pooling.0")]
+ new_key = key.replace(prefix, mapping[prefix])
+ new_checkpoint[new_key] = unet_state_dict[key]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
+ if not isinstance(config["layers_per_block"], int):
+ layers_per_block_list = [e + 1 for e in config["layers_per_block"]]
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
+ downsampler_ids = layers_per_block_cumsum
+ else:
+ # TODO need better check than i in [4, 8, 12, 16]
+ downsampler_ids = [4, 8, 12, 16]
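+    # in the list case, e.g. layers_per_block (2, 2, 2) gives layers_per_block_list
+    # [3, 3, 3] and cumsum [3, 6, 9], i.e. input_blocks 3, 6 and 9 hold downsamplers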
+
+ for i in range(1, num_input_blocks):
+ if isinstance(config["layers_per_block"], int):
+ layers_per_block = config["layers_per_block"]
+ block_id = (i - 1) // (layers_per_block + 1)
+ layer_in_block_id = (i - 1) % (layers_per_block + 1)
+ else:
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if (i - 1) < n)
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
+ layer_in_block_id = (i - 1) - passed_blocks
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+
+ block_type = config["down_block_types"][block_id]
+ if (
+ block_type == "ResnetDownsampleBlock2D" or block_type == "SimpleCrossAttnDownBlock2D"
+ ) and i in downsampler_ids:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.downsamplers.0"}
+ else:
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ old_path = f"input_blocks.{i}.1"
+ new_path = f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": old_path, "new": new_path}
+ assign_to_checkpoint(
+ paths,
+ new_checkpoint,
+ unet_state_dict,
+ additional_replacements=[meta_path],
+ config=config,
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ old_path = "middle_block.1"
+ new_path = "mid_block.attentions.0"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ if not isinstance(config["layers_per_block"], int):
+ layers_per_block_list = list(reversed([e + 1 for e in config["layers_per_block"]]))
+ layers_per_block_cumsum = list(np.cumsum(layers_per_block_list))
+
+ for i in range(num_output_blocks):
+ if isinstance(config["layers_per_block"], int):
+ layers_per_block = config["layers_per_block"]
+ block_id = i // (layers_per_block + 1)
+ layer_in_block_id = i % (layers_per_block + 1)
+ else:
+ block_id = next(k for k, n in enumerate(layers_per_block_cumsum) if i < n)
+ passed_blocks = layers_per_block_cumsum[block_id - 1] if block_id > 0 else 0
+ layer_in_block_id = i - passed_blocks
+
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ # len(output_block_list) == 1 -> resnet
+ # len(output_block_list) == 2 -> resnet, attention or resnet, upscale resnet
+ # len(output_block_list) == 3 -> resnet, attention, upscale resnet
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+
+ has_attention = True
+ if len(output_block_list) == 2 and any("in_layers" in k for k in output_block_list["1"]):
+ has_attention = False
+
+ maybe_attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+
+            # the "1" entry here was a bare upsampler conv, not an attention block
+ has_attention = False
+ maybe_attentions = []
+
+ if has_attention:
+ old_path = f"output_blocks.{i}.1"
+ new_path = f"up_blocks.{block_id}.attentions.{layer_in_block_id}"
+
+ assign_attention_to_checkpoint(
+ new_checkpoint=new_checkpoint,
+ unet_state_dict=unet_state_dict,
+ old_path=old_path,
+ new_path=new_path,
+ config=config,
+ )
+
+ paths = renew_attention_paths(maybe_attentions)
+ meta_path = {
+ "old": old_path,
+ "new": new_path,
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(output_block_list) == 3 or (not has_attention and len(maybe_attentions) > 0):
+ layer_id = len(output_block_list) - 1
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.{layer_id}" in key]
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"output_blocks.{i}.{layer_id}", "new": f"up_blocks.{block_id}.upsamplers.0"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ return new_checkpoint
+
+
+def verify_param_count(orig_path, unet_diffusers_config):
+ if "-II-" in orig_path:
+ from deepfloyd_if.modules import IFStageII
+
+ if_II = IFStageII(device="cpu", dir_or_name=orig_path)
+ elif "-III-" in orig_path:
+ from deepfloyd_if.modules import IFStageIII
+
+ if_II = IFStageIII(device="cpu", dir_or_name=orig_path)
+ else:
+ assert f"Weird name. Should have -II- or -III- in path: {orig_path}"
+
+ unet = UNet2DConditionModel(**unet_diffusers_config)
+
+ # in params
+ assert_param_count(unet.time_embedding, if_II.model.time_embed)
+ assert_param_count(unet.conv_in, if_II.model.input_blocks[:1])
+
+ # downblocks
+ assert_param_count(unet.down_blocks[0], if_II.model.input_blocks[1:4])
+ assert_param_count(unet.down_blocks[1], if_II.model.input_blocks[4:7])
+ assert_param_count(unet.down_blocks[2], if_II.model.input_blocks[7:11])
+
+ if "-II-" in orig_path:
+ assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:17])
+ assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[17:])
+ if "-III-" in orig_path:
+ assert_param_count(unet.down_blocks[3], if_II.model.input_blocks[11:15])
+ assert_param_count(unet.down_blocks[4], if_II.model.input_blocks[15:20])
+ assert_param_count(unet.down_blocks[5], if_II.model.input_blocks[20:])
+
+ # mid block
+ assert_param_count(unet.mid_block, if_II.model.middle_block)
+
+ # up block
+ if "-II-" in orig_path:
+ assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:6])
+ assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[6:12])
+ assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[12:16])
+ assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[16:19])
+ assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[19:])
+ if "-III-" in orig_path:
+ assert_param_count(unet.up_blocks[0], if_II.model.output_blocks[:5])
+ assert_param_count(unet.up_blocks[1], if_II.model.output_blocks[5:10])
+ assert_param_count(unet.up_blocks[2], if_II.model.output_blocks[10:14])
+ assert_param_count(unet.up_blocks[3], if_II.model.output_blocks[14:18])
+ assert_param_count(unet.up_blocks[4], if_II.model.output_blocks[18:21])
+ assert_param_count(unet.up_blocks[5], if_II.model.output_blocks[21:24])
+
+ # out params
+ assert_param_count(unet.conv_norm_out, if_II.model.out[0])
+ assert_param_count(unet.conv_out, if_II.model.out[2])
+
+    # make sure the full models have the same param count
+ assert_param_count(unet, if_II.model)
+
+
+def assert_param_count(model_1, model_2):
+ count_1 = sum(p.numel() for p in model_1.parameters())
+ count_2 = sum(p.numel() for p in model_2.parameters())
+ assert count_1 == count_2, f"{model_1.__class__}: {count_1} != {model_2.__class__}: {count_2}"
+
+
+def superres_check_against_original(dump_path, unet_checkpoint_path):
+ model_path = dump_path
+ model = UNet2DConditionModel.from_pretrained(model_path)
+ model.to("cuda")
+ orig_path = unet_checkpoint_path
+
+ if "-II-" in orig_path:
+ from deepfloyd_if.modules import IFStageII
+
+ if_II_model = IFStageII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
+ elif "-III-" in orig_path:
+ from deepfloyd_if.modules import IFStageIII
+
+ if_II_model = IFStageIII(device="cuda", dir_or_name=orig_path, model_kwargs={"precision": "fp32"}).model
+
+ batch_size = 1
+ channels = model.config.in_channels // 2
+    # hardcode the test resolution (this overrides model.config.sample_size)
+    height = 1024
+    width = 1024
+
+ torch.manual_seed(0)
+
+ latents = torch.randn((batch_size, channels, height, width), device=model.device)
+ image_small = torch.randn((batch_size, channels, height // 4, width // 4), device=model.device)
+
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ latent_model_input = torch.cat([latents, image_upscaled], dim=1).to(model.dtype)
+ t = torch.tensor([5], device=model.device).to(model.dtype)
+
+ seq_len = 64
+ encoder_hidden_states = torch.randn((batch_size, seq_len, model.config.encoder_hid_dim), device=model.device).to(
+ model.dtype
+ )
+
+ fake_class_labels = torch.tensor([t], device=model.device).to(model.dtype)
+
+ with torch.no_grad():
+ out = if_II_model(latent_model_input, t, aug_steps=fake_class_labels, text_emb=encoder_hidden_states)
+
+ if_II_model.to("cpu")
+ del if_II_model
+ import gc
+
+ torch.cuda.empty_cache()
+ gc.collect()
+ print(50 * "=")
+
+ with torch.no_grad():
+ noise_pred = model(
+ sample=latent_model_input,
+ encoder_hidden_states=encoder_hidden_states,
+ class_labels=fake_class_labels,
+ timestep=t,
+ ).sample
+
+ print("Out shape", noise_pred.shape)
+ print("Diff", (out - noise_pred).abs().sum())
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/diffusers/scripts/convert_k_upscaler_to_diffusers.py b/diffusers/scripts/convert_k_upscaler_to_diffusers.py
new file mode 100755
index 0000000..62abedd
--- /dev/null
+++ b/diffusers/scripts/convert_k_upscaler_to_diffusers.py
@@ -0,0 +1,297 @@
+import argparse
+
+import huggingface_hub
+import k_diffusion as K
+import torch
+
+from diffusers import UNet2DConditionModel
+
+
+UPSCALER_REPO = "pcuenq/k-upscaler"
+
+
+def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
+ rv = {
+ # norm1
+ f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
+ f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
+ # conv1
+ f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
+ f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
+ # norm2
+ f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
+ f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
+ # conv2
+ f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
+ f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
+ }
+
+ if resnet.conv_shortcut is not None:
+ rv.update(
+ {
+ f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
+ }
+ )
+
+ return rv
+
+
+def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
+ weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
+ bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
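+    # qkv_proj is a fused 1x1 conv; chunking along dim 0 separates the q/k/v
+    # projections, and the trailing singleton conv dims are squeezed off below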
+ rv = {
+ # norm
+ f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
+ f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
+ # to_q
+ f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
+ # to_k
+ f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
+ # to_v
+ f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
+ # to_out
+ f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
+ }
+
+ return rv
+
+
+def cross_attn_to_diffusers_checkpoint(
+ checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
+):
+ weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
+ bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)
+
+ rv = {
+ # norm2 (ada groupnorm)
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
+ f"{attention_prefix}.norm_dec.mapper.weight"
+ ],
+ f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
+ f"{attention_prefix}.norm_dec.mapper.bias"
+ ],
+ # layernorm on encoder_hidden_state
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
+ f"{attention_prefix}.norm_enc.weight"
+ ],
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
+ f"{attention_prefix}.norm_enc.bias"
+ ],
+ # to_q
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
+ f"{attention_prefix}.q_proj.weight"
+ ]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
+ f"{attention_prefix}.q_proj.bias"
+ ],
+ # to_k
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
+ # to_v
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
+ # to_out
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
+ f"{attention_prefix}.out_proj.weight"
+ ]
+ .squeeze(-1)
+ .squeeze(-1),
+ f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
+ f"{attention_prefix}.out_proj.bias"
+ ],
+ }
+
+ return rv
+
+
+def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
+ block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
+ block_prefix = f"{block_prefix}.{block_idx}"
+
+ diffusers_checkpoint = {}
+
+ if not hasattr(block, "attentions"):
+ n = 1 # resnet only
+ elif not block.attentions[0].add_self_attention:
+ n = 2 # resnet -> cross-attention
+ else:
+        n = 3  # resnet -> self-attention -> cross-attention
+
+ for resnet_idx, resnet in enumerate(block.resnets):
+        diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
+        idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
+        resnet_prefix = f"{block_prefix}.{idx}"
+
+ diffusers_checkpoint.update(
+ resnet_to_diffusers_checkpoint(
+ resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
+ )
+ )
+
+ if hasattr(block, "attentions"):
+ for attention_idx, attention in enumerate(block.attentions):
+ diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
+            idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
+            self_attention_prefix = f"{block_prefix}.{idx}"
+            cross_attention_index = 1 if not attention.add_self_attention else 2
+            idx = (
+                n * attention_idx + cross_attention_index
+                if block_type == "up"
+                else n * attention_idx + cross_attention_index + 1
+            )
+            cross_attention_prefix = f"{block_prefix}.{idx}"
+
+ diffusers_checkpoint.update(
+ cross_attn_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ diffusers_attention_index=2,
+ attention_prefix=cross_attention_prefix,
+ )
+ )
+
+ if attention.add_self_attention is True:
+ diffusers_checkpoint.update(
+ self_attn_to_diffusers_checkpoint(
+ checkpoint,
+ diffusers_attention_prefix=diffusers_attention_prefix,
+ attention_prefix=self_attention_prefix,
+ )
+ )
+
+ return diffusers_checkpoint
+
+
+def unet_to_diffusers_checkpoint(model, checkpoint):
+ diffusers_checkpoint = {}
+
+ # pre-processing
+ diffusers_checkpoint.update(
+ {
+ "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
+ "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
+ }
+ )
+
+ # timestep and class embedding
+ diffusers_checkpoint.update(
+ {
+ "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
+ "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
+ "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
+ "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
+ "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
+ "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
+ }
+ )
+
+ # down_blocks
+ for down_block_idx, down_block in enumerate(model.down_blocks):
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))
+
+ # up_blocks
+ for up_block_idx, up_block in enumerate(model.up_blocks):
+ diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))
+
+ # post-processing
+ diffusers_checkpoint.update(
+ {
+ "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
+ "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
+ }
+ )
+
+ return diffusers_checkpoint
+
+
+def unet_model_from_original_config(original_config):
+ in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
+ out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)
+
+ block_out_channels = original_config["channels"]
+
+ assert (
+ len(set(original_config["depths"])) == 1
+ ), "UNet2DConditionModel currently do not support blocks with different number of layers"
+ layers_per_block = original_config["depths"][0]
+
+ class_labels_dim = original_config["mapping_cond_dim"]
+ cross_attention_dim = original_config["cross_cond_dim"]
+
+ attn1_types = []
+ attn2_types = []
+ for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
+ if s:
+ a1 = "self"
+ a2 = "cross" if c else None
+ elif c:
+ a1 = "cross"
+ a2 = None
+ else:
+ a1 = None
+ a2 = None
+ attn1_types.append(a1)
+ attn2_types.append(a2)
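+    # note: attn1_types / attn2_types are not used below; the down/up block types
+    # passed to UNet2DConditionModel are hardcoded for this upscaler architecture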
+
+ unet = UNet2DConditionModel(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
+ mid_block_type=None,
+ up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn="gelu",
+ norm_num_groups=None,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=64,
+ time_cond_proj_dim=class_labels_dim,
+ resnet_time_scale_shift="scale_shift",
+ time_embedding_type="fourier",
+ timestep_post_act="gelu",
+ conv_in_kernel=1,
+ conv_out_kernel=1,
+ )
+
+ return unet
+
+
+def main(args):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
+ orig_weights_path = huggingface_hub.hf_hub_download(
+ UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
+ )
+ print(f"loading original model configuration from {orig_config_path}")
+ print(f"loading original model checkpoint from {orig_weights_path}")
+
+ print("converting to diffusers unet")
+ orig_config = K.config.load_config(open(orig_config_path))["model"]
+ model = unet_model_from_original_config(orig_config)
+
+ orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
+ converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)
+
+ model.load_state_dict(converted_checkpoint, strict=True)
+ model.save_pretrained(args.dump_path)
+ print(f"saving converted unet model in {args.dump_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py
new file mode 100755
index 0000000..5135eae
--- /dev/null
+++ b/diffusers/scripts/convert_kakao_brain_unclip_to_diffusers.py
@@ -0,0 +1,1159 @@
+import argparse
+import tempfile
+
+import torch
+from accelerate import load_checkpoint_and_dispatch
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
+from diffusers.models.transformers.prior_transformer import PriorTransformer
+from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
+from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
+
+
+r"""
+Example - From the diffusers root directory:
+
+Download weights:
+```sh
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt
+$ wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th
+```
+
+Convert the model:
+```sh
+$ python scripts/convert_kakao_brain_unclip_to_diffusers.py \
+ --decoder_checkpoint_path ./decoder-ckpt-step\=01000000-of-01000000.ckpt \
+ --super_res_unet_checkpoint_path ./improved-sr-ckpt-step\=1.2M.ckpt \
+ --prior_checkpoint_path ./prior-ckpt-step\=01000000-of-01000000.ckpt \
+ --clip_stat_path ./ViT-L-14_stats.th \
+ --dump_path