From 3259814031865c4e8875431037d8dc19d4fc7393 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 6 May 2026 14:45:06 +0800 Subject: [PATCH 01/16] Update Self-Forcing --- examples/wan2.1_self_forcing/predict_t2v.py | 249 ++ scripts/wan2.1_self_forcing/train_distill.py | 2726 +++++++++++++++++ scripts/wan2.1_self_forcing/train_distill.sh | 43 + videox_fun/models/__init__.py | 1 + .../models/wan_transformer3d_self_forcing.py | 1098 +++++++ videox_fun/pipeline/__init__.py | 1 + .../pipeline/pipeline_wan_self_forcing.py | 842 +++++ 7 files changed, 4960 insertions(+) create mode 100644 examples/wan2.1_self_forcing/predict_t2v.py create mode 100644 scripts/wan2.1_self_forcing/train_distill.py create mode 100644 scripts/wan2.1_self_forcing/train_distill.sh create mode 100644 videox_fun/models/wan_transformer3d_self_forcing.py create mode 100644 videox_fun/pipeline/pipeline_wan_self_forcing.py diff --git a/examples/wan2.1_self_forcing/predict_t2v.py b/examples/wan2.1_self_forcing/predict_t2v.py new file mode 100644 index 00000000..483a5dbe --- /dev/null +++ b/examples/wan2.1_self_forcing/predict_t2v.py @@ -0,0 +1,249 @@ +import os +import sys + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from omegaconf import OmegaConf +from PIL import Image + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.dist import set_multi_gpus_devices, shard_model +from videox_fun.models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, + WanTransformer3DModel_SelfForcing) +from videox_fun.pipeline import WanSelfForcingPipeline +from videox_fun.utils import (register_auto_device_hook, + safe_enable_group_offload) 
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
                                               convert_weight_dtype_wrapper,
                                               replace_parameters_by_name)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent,
                                    save_videos_grid)

# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, model_group_offload, sequential_cpu_offload].
# model_full_load means that the entire model will be moved to the GPU.
#
# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU,
# and the transformer model has been quantized to float8, which can save more GPU memory.
#
# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
#
# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
# and the transformer model has been quantized to float8, which can save more GPU memory.
#
# model_group_offload transfers internal layer groups between CPU/CUDA,
# balancing memory efficiency and speed between full-module and leaf-level offloading methods.
#
# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
# resulting in slower speeds but saving a large amount of GPU memory.
GPU_memory_mode = "sequential_cpu_offload"
# Multi-GPU config. These two values are passed to set_multi_gpus_devices and
# are also read by the final save guard; 1 x 1 reproduces the values that were
# previously hard-coded in the set_multi_gpus_devices call.
# NOTE(review): presumably values > 1 require a distributed launch - confirm.
ulysses_degree = 1
ring_degree = 1
# Compile will give a speedup in fixed resolution and need a little GPU memory.
# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload.
compile_dit = False

# Config and model path
config_path = "config/wan2.1/wan_civitai.yaml"
# model path
model_name = "models/Diffusion_Transformer/Wan2.1-T2V-1.3B"

# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++"
sampler_name = "Flow"
# [NOTE]: Noise schedule shift parameter. Affects temporal dynamics.
# Used when the sampler is in "Flow_Unipc", "Flow_DPM++".
shift = 5

# Load pretrained model if need
transformer_path = "models/Diffusion_Transformer/Self-Forcing/checkpoints/self_forcing_dmd.pt"
vae_path = None
lora_path = None

# Other params
sample_size = [480, 832]
video_length = 81
fps = 16

# Self-Forcing causal inference config
# Number of frames to generate per block (1 for standard causal, higher for faster but more memory)
num_frame_per_block = 3
# Local attention window size (-1 for global attention)
local_attn_size = -1
# Others
independent_first_frame = False
context_noise = 0.0

# Use torch.float16 if GPU does not support torch.bfloat16
# Some graphics cards, such as v100, 2080ti, do not support torch.bfloat16
weight_dtype = torch.bfloat16
prompt = "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
guidance_scale = 1.0
seed = 43
num_inference_steps = 4
lora_weight = 0.55
save_path = "samples/wan-videos-self-forcing-t2v"

device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)

# Load transformer with causal inference support if enabled
transformer_additional_kwargs = OmegaConf.to_container(config['transformer_additional_kwargs'])
transformer_additional_kwargs['local_attn_size'] = local_attn_size

transformer = WanTransformer3DModel_SelfForcing.from_pretrained(
    os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
    transformer_additional_kwargs=transformer_additional_kwargs,
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)

if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(transformer_path)
    else:
        # NOTE(security): torch.load unpickles the file; only point
        # transformer_path at checkpoints from a trusted source.
        state_dict = torch.load(transformer_path, map_location="cpu")

    # Checkpoints may wrap the weights under "generator_ema" and/or prefix
    # every key with "model."; unwrap both before loading.
    state_dict = state_dict["generator_ema"] if "generator_ema" in state_dict else state_dict
    if any(k.startswith("model.") for k in state_dict.keys()):
        state_dict = {k.replace("model.", "", 1) if k.startswith("model.") else k: v for k, v in state_dict.items()}

    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Get Vae
vae = AutoencoderKLWan.from_pretrained(
    os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
).to(weight_dtype)

if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if vae_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(vae_path)
    else:
        # NOTE(security): torch.load unpickles the file; only point vae_path
        # at checkpoints from a trusted source.
        state_dict = torch.load(vae_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# Get Tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
)

# Get Text encoder
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)

# Get Scheduler
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
Chosen_Scheduler = scheduler_dict[sampler_name]
if sampler_name == "Flow_Unipc" or sampler_name == "Flow_DPM++":
    # For these two samplers the scheduler is built with shift=1; the
    # user-facing `shift` above is passed to the pipeline call instead.
    config['scheduler_kwargs']['shift'] = 1
scheduler = Chosen_Scheduler(
    **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)

# Get Pipeline
pipeline = WanSelfForcingPipeline(
    transformer=transformer,
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    scheduler=scheduler,
)

if compile_dit:
    for i in range(len(pipeline.transformer.blocks)):
        pipeline.transformer.blocks[i] = torch.compile(pipeline.transformer.blocks[i])
    print("Add Compile")

if GPU_memory_mode == "sequential_cpu_offload":
    replace_parameters_by_name(transformer, ["modulation",], device=device)
    transformer.freqs = transformer.freqs.to(device=device)
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode == "model_group_offload":
    register_auto_device_hook(pipeline.transformer)
    safe_enable_group_offload(pipeline, onload_device=device, offload_device="cpu", offload_type="leaf_level", use_stream=True)
elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload":
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_full_load_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.to(device=device)
else:
    pipeline.to(device=device)

generator = torch.Generator(device=device).manual_seed(seed)

if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

with torch.no_grad():
    # Snap the requested length down to k * temporal_compression_ratio + 1
    # frames (a single frame is kept as-is).
    video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1

    sample = pipeline(
        prompt,
        num_frames = video_length,
        negative_prompt = negative_prompt,
        height = sample_size[0],
        width = sample_size[1],
        generator = generator,
        guidance_scale = guidance_scale,
        num_inference_steps = num_inference_steps,
        shift = shift,
        num_frame_per_block = num_frame_per_block,
        independent_first_frame = independent_first_frame,
        context_noise = context_noise,
    ).videos

if lora_path is not None:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

def save_results():
    """Write the generated sample under save_path as a PNG (one frame) or MP4."""
    os.makedirs(save_path, exist_ok=True)

    # File names are a zero-padded running index based on how many entries
    # the output directory already holds.
    index = len(os.listdir(save_path)) + 1
    prefix = str(index).zfill(8)
    if video_length == 1:
        video_path = os.path.join(save_path, prefix + ".png")

        # Take the first frame of the first sample and move the channel axis
        # last before converting to uint8 for PIL.
        image = sample[0, :, 0]
        image = image.transpose(0, 1).transpose(1, 2)
        image = (image * 255).numpy().astype(np.uint8)
        image = Image.fromarray(image)
        image.save(video_path)
    else:
        video_path = os.path.join(save_path, prefix + ".mp4")
        save_videos_grid(sample, video_path, fps=fps)

# With more than one process only rank 0 writes the result. ulysses_degree and
# ring_degree are defined next to GPU_memory_mode above; before they were
# defined this guard raised NameError and nothing was ever saved.
if ulysses_degree * ring_degree > 1:
    import torch.distributed as dist
    if dist.get_rank() == 0:
        save_results()
else:
    save_results()
+# See the License for the specific language governing permissions and + +import argparse +import contextlib +import gc +import json +import logging +import math +import os +import pickle +import shutil +import sys + +import accelerate +import diffusers +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +from accelerate import Accelerator, FullyShardedDataParallelPlugin +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import (EMAModel, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3) +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, FullStateDictConfig, ShardedOptimStateDictConfig, + ShardedStateDictConfig) +from torch.utils.data import BatchSampler, Dataset, RandomSampler +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.utils import ContextManagers + +import datasets + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.data 
def initialize_kv_cache_for_training(batch_size, num_frames, frame_seq_length, num_layers, num_heads, head_dim, dtype, device):
    """Create an empty self-attention KV cache with one entry per layer.

    Each entry is a dict with zero-filled "k" and "v" tensors of shape
    [batch_size, num_frames * frame_seq_length, num_heads, head_dim] plus two
    one-element int64 tensors, "global_end_index" and "local_end_index",
    both starting at 0. Every layer gets its own freshly allocated tensors.

    Returns:
        A list of ``num_layers`` cache dicts.
    """
    cache_shape = [batch_size, num_frames * frame_seq_length, num_heads, head_dim]

    def _empty_layer_cache():
        # One independent set of buffers and write cursors for a single layer.
        return {
            "k": torch.zeros(cache_shape, dtype=dtype, device=device),
            "v": torch.zeros(cache_shape, dtype=dtype, device=device),
            "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
            "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
        }

    return [_empty_layer_cache() for _layer in range(num_layers)]


def initialize_crossattn_cache_for_training(batch_size, text_len, num_layers, num_heads, head_dim, dtype, device):
    """Create an empty cross-attention cache with one entry per layer.

    Each entry is a dict with zero-filled "k" and "v" tensors of shape
    [batch_size, text_len, num_heads, head_dim] and an "is_init" flag set to
    False.

    Returns:
        A list of ``num_layers`` cache dicts.
    """
    cache_shape = [batch_size, text_len, num_heads, head_dim]

    def _empty_layer_cache():
        return {
            "k": torch.zeros(cache_shape, dtype=dtype, device=device),
            "v": torch.zeros(cache_shape, dtype=dtype, device=device),
            "is_init": False,
        }

    return [_empty_layer_cache() for _layer in range(num_layers)]
def get_random_downsample_ratio(sample_size, image_ratio=None,
                                all_choices=False, rng=None):
    """Pick a downsample ratio whose candidate set grows with ``sample_size``.

    Args:
        sample_size: Resolution used to select the candidate ratios. Below
            512 the only candidate is 1 and ``image_ratio`` is ignored.
        image_ratio: Optional extra ratios appended to the candidates when
            ``sample_size >= 512``. ``None`` (the default) means no extras.
            The argument is never mutated.
        all_choices: When True, return the full candidate list instead of
            sampling from it.
        rng: Optional object exposing ``choice(values, p=...)`` (for example
            a NumPy random generator). Falls back to ``np.random.choice``.

    Returns:
        The candidate list when ``all_choices`` is True, otherwise one ratio
        drawn with probability 0.75 for the first candidate and the remaining
        0.25 split evenly across the others.
    """
    # A fresh list per call avoids the shared-mutable-default pitfall.
    extra_ratios = [] if image_ratio is None else list(image_ratio)

    if sample_size >= 1536:
        number_list = [1, 1.25, 1.5, 2, 2.5, 3] + extra_ratios
    elif sample_size >= 1024:
        number_list = [1, 1.25, 1.5, 2] + extra_ratios
    elif sample_size >= 768:
        number_list = [1, 1.25, 1.5] + extra_ratios
    elif sample_size >= 512:
        number_list = [1] + extra_ratios
    else:
        number_list = [1]

    if all_choices:
        return number_list

    # number_list always has at least one entry, so the two cases below are
    # exhaustive.
    n_choices = len(number_list)
    if n_choices == 1:
        probabilities = [1.0]
    else:
        first_prob = 0.75
        other_prob = (1.0 - first_prob) / (n_choices - 1)
        probabilities = [first_prob] + [other_prob] * (n_choices - 1)
    number_list_prob = np.array(probabilities)

    if rng is None:
        return np.random.choice(number_list, p=number_list_prob)
    return rng.choice(number_list, p=number_list_prob)
def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, transformer3d, args, config, accelerator, weight_dtype, global_step):
    """Render the validation prompts with the current transformer and save MP4s.

    A WanSelfForcingPipeline is built around ``transformer3d`` and run once per
    entry of ``args.validation_prompts``; each result is written to
    ``<args.output_dir>/sample/sample-<global_step>-rank<process_index>-image-<i>.mp4``.
    Only ``args.train_mode == "normal"`` is supported.

    Validation is best-effort: any exception raised while building or running
    the pipeline is printed and swallowed so training can continue. Regardless
    of the outcome, CUDA caches are cleared, the DeepSpeed ``config`` swap is
    undone, and ``vae`` (plus ``text_encoder`` unless
    ``args.enable_text_encoder_in_dataloader`` is set) is moved back to the
    accelerator device, or to CPU when ``args.low_vram`` is set.

    Args:
        vae, text_encoder, tokenizer: Components handed to the pipeline.
        clip_image_encoder: Unused here; kept for call-site compatibility.
        transformer3d: The (possibly wrapped) transformer being trained.
        args: Parsed command-line namespace from ``parse_args``.
        config: Model config; ``config['scheduler_kwargs']`` feeds the scheduler.
        accelerator: The Accelerate object providing device and rank.
        weight_dtype: dtype used for autocast and for restoring the modules.
        global_step: Current training step, used in the output file name.
    """
    is_deepspeed = type(transformer3d).__name__ == 'DeepSpeedEngine'
    origin_config = None
    config_swapped = False
    pipeline = None
    try:
        if is_deepspeed:
            # Temporarily expose the unwrapped model's config on the engine;
            # this is undone in the finally block.
            origin_config = transformer3d.config
            transformer3d.config = accelerator.unwrap_model(transformer3d).config
            config_swapped = True

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device):
            logger.info("Running validation... ")
            scheduler = FlowMatchEulerDiscreteScheduler(
                **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs']))
            )

            if args.train_mode != "normal":
                raise NotImplementedError(f"Validation for train_mode '{args.train_mode}' is not yet supported with WanSelfForcingPipeline. Only T2V (train_mode='normal') is currently supported.")

            pipeline = WanSelfForcingPipeline(
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                transformer=accelerator.unwrap_model(transformer3d) if type(transformer3d).__name__ == 'DistributedDataParallel' else transformer3d,
                scheduler=scheduler,
            )
            pipeline = pipeline.to(accelerator.device)

            if args.seed is None:
                generator = None
            else:
                # Offset the seed by rank so each process renders a different sample.
                rank_seed = args.seed + accelerator.process_index
                generator = torch.Generator(device=accelerator.device).manual_seed(rank_seed)
                logger.info(f"Rank {accelerator.process_index} using seed: {rank_seed}")

            # --fix_sample_size is a [height, width] pair, while
            # --video_sample_size is declared as a single int; unpacking the
            # int raised a TypeError that the handler below swallowed, so
            # validation never produced output without --fix_sample_size.
            if args.fix_sample_size is not None:
                height, width = args.fix_sample_size
            elif isinstance(args.video_sample_size, int):
                height = width = args.video_sample_size
            else:
                height, width = args.video_sample_size

            sample_dir = os.path.join(args.output_dir, "sample")
            # --validation_prompts defaults to None; treat that as "nothing to render".
            for i, validation_prompt in enumerate(args.validation_prompts or []):
                sample = pipeline(
                    validation_prompt,
                    num_frames = args.video_sample_n_frames,
                    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
                    height = height,
                    width = width,
                    generator = generator,
                    guidance_scale = 1.0,
                    num_inference_steps = len(args.denoising_step_indices_list),
                    num_frame_per_block = args.num_frame_per_block,
                    independent_first_frame = args.independent_first_frame,
                    context_noise = args.context_noise,
                ).videos
                os.makedirs(sample_dir, exist_ok=True)
                save_videos_grid(
                    sample,
                    os.path.join(
                        args.output_dir,
                        f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4"
                    )
                )
    except Exception as e:
        # Deliberately best-effort: a failed validation must not stop training.
        print(f"Eval error on rank {accelerator.process_index} with info {e}")
    finally:
        if config_swapped:
            transformer3d.config = origin_config
        # Drop the pipeline reference before collecting so its memory can be freed.
        pipeline = None
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        restore_device = accelerator.device if not args.low_vram else "cpu"
        vae.to(restore_device, dtype=weight_dtype)
        if not args.enable_text_encoder_in_dataloader:
            text_encoder.to(restore_device, dtype=weight_dtype)
def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None):
    """Sample integer timesteps in ``[low, high - 1]``, concentrated mid-range.

    Standard-normal draws are passed through a sigmoid, rescaled to span
    ``low``..``high``, truncated to int32 and finally clipped so the upper
    bound ``high`` itself is never returned.

    Args:
        low: Smallest timestep that can be returned.
        high: Exclusive upper bound of the returned timesteps.
        shape: Shape of the returned tensor.
        device: Device on which the samples are drawn.
        generator: Optional ``torch.Generator`` for reproducible draws.

    Returns:
        An int32 tensor of the requested shape.
    """
    gaussian = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator)
    # Sigmoid written out explicitly: maps the draws into the unit interval.
    unit_interval = 1 / (1 + torch.exp(-gaussian))
    scaled = unit_interval * (high - low) + low
    timesteps = scaled.to(torch.int32)
    return torch.clip(timesteps, low, high - 1)
+ ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--validation_paths", + type=str, + default=None, + nargs="+", + help=("A set of control videos evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + help=("The negative prompt of cfg distill"), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--learning_rate_critic", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." 
+ ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_model_info", action="store_true", help="Whether or not to report more info about model (such as norm, grad)." + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. 
Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_frame_crop", action="store_true", help="Whether enable random frame crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." + ) + parser.add_argument( + "--training_with_video_token_length", action="store_true", help="The training stage of the model in training.", + ) + parser.add_argument( + "--auto_tile_batch_size", action="store_true", help="Whether to auto tile batch size.", + ) + parser.add_argument( + "--motion_sub_loss", action="store_true", help="Whether enable motion sub loss." 
+ ) + parser.add_argument( + "--motion_sub_loss_ratio", type=float, default=0.25, help="The ratio of motion sub loss." + ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--keep_all_node_same_token_length", + action="store_true", + help="Reference of the length token.", + ) + parser.add_argument( + "--token_sample_size", + type=int, + default=512, + help="Sample size of the token.", + ) + parser.add_argument( + "--video_sample_size", + type=int, + default=512, + help="Sample size of the video.", + ) + parser.add_argument( + "--image_sample_size", + type=int, + default=512, + help="Sample size of the image.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." + ) + parser.add_argument( + "--video_sample_stride", + type=int, + default=4, + help="Sample stride of the video.", + ) + parser.add_argument( + "--video_sample_n_frames", + type=int, + default=17, + help="Num frame of video.", + ) + parser.add_argument( + "--video_repeat", + type=int, + default=0, + help="Num of repeat video.", + ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help=( + "The config of the model in training." 
+ ), + ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + + parser.add_argument( + '--trainable_modules', + nargs='+', + help='Enter a list of trainable modules' + ) + parser.add_argument( + '--trainable_modules_low_learning_rate', + nargs='+', + default=[], + help='Enter a list of trainable modules with lower learning rate' + ) + parser.add_argument( + '--tokenizer_max_length', + type=int, + default=512, + help='Max length of tokenizer' + ) + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--train_mode", + type=str, + default="normal", + help=( + 'The format of training data. Support `"normal"`' + ' (default), `"i2v"`.' + ), + ) + parser.add_argument( + "--gen_update_interval", + type=int, + default=5, + help="The ratio to update transformer3d.", + ) + parser.add_argument( + "--fake_guidance_scale", + type=float, + default=0.0, + help="The cfg scale for fake iscore.", + ) + parser.add_argument( + "--real_guidance_scale", + type=float, + default=6.0, + help="The cfg scale for real score.", + ) + parser.add_argument( + '--denoising_step_indices_list', + nargs='+', + default=[1000, 750, 500, 250], + help="The denoising step list.", + ) + parser.add_argument( + "--num_frame_per_block", + type=int, + default=3, + help="Number of frames per block for Self-Forcing causal training" + ) + parser.add_argument( + "--independent_first_frame", + action="store_true", + help="Whether first frame is independent ([1, N, N, ...] 
pattern)" + ) + parser.add_argument( + "--use_kv_cache_training", + action="store_true", + help="Use KV cache block-by-block training (matches original Self-Forcing)" + ) + parser.add_argument( + "--context_noise", + type=int, + default=0, + help="Context noise level for KV cache update (matches training config)" + ) + parser.add_argument( + "--use_teacher_forcing", + action="store_true", + help="Enable teacher forcing training (pass clean_x to transformer)" + ) + parser.add_argument( + "--teacher_forcing_prob", + type=float, + default=1.0, + help="Probability of applying teacher forcing per step (1.0 = always)" + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." 
+ ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + config = OmegaConf.load(args.config_path) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + accelerator_fake_score_transformer3d = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. 
+ fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. 
+ weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. + noise_scheduler = FlowMatchEulerDiscreteScheduler( + **filter_kwargs(FlowMatchEulerDiscreteScheduler, OmegaConf.to_container(config['scheduler_kwargs'])) + ) + + # Get Tokenizer + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + ) + text_encoder = text_encoder.eval() + # Get Vae + vae = AutoencoderKLWan.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(config['vae_kwargs']), + ) + vae.eval() + # Get Clip Image Encoder + if args.train_mode != "normal": + clip_image_encoder = CLIPModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), + ) + clip_image_encoder = clip_image_encoder.eval() + else: + clip_image_encoder = None + + # Get Transformer + generator_transformer3d = WanTransformer3DModel_SelfForcing.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + ).to(weight_dtype) + real_score_transformer3d = WanTransformer3DModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + ).to(weight_dtype) + fake_score_transformer3d = WanTransformer3DModel.from_pretrained( + os.path.join(args.pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + 
transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + ).to(weight_dtype) + + # Freeze vae and text_encoder and set generator_transformer3d to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + generator_transformer3d.requires_grad_(False) + real_score_transformer3d.requires_grad_(False) + fake_score_transformer3d.requires_grad_(False) + if args.train_mode != "normal": + clip_image_encoder.requires_grad_(False) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = generator_transformer3d.load_state_dict(state_dict, strict=False) + m, u = real_score_transformer3d.load_state_dict(state_dict, strict=False) + m, u = fake_score_transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + # A good trainable modules is showed below now. 
+ # For 3D Patch: trainable_modules = ['ff.net', 'pos_embed', 'attn2', 'proj_out', 'timepositionalencoding', 'h_position', 'w_position'] + # For 2D Patch: trainable_modules = ['ff.net', 'attn2', 'timepositionalencoding', 'h_position', 'w_position'] + generator_transformer3d.train() + fake_score_transformer3d.train() + if accelerator.is_main_process: + accelerator.print( + f"Trainable modules '{args.trainable_modules}'." + ) + for name, param in generator_transformer3d.named_parameters(): + for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + param.requires_grad = True + break + for name, param in fake_score_transformer3d.named_parameters(): + for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + param.requires_grad = True + break + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + if fsdp_stage != 0 or zero_stage == 3: + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + + safetensor_save_path = os.path.join(output_dir, f"diffusion_pytorch_model.safetensors") + accelerate_state_dict = {k: v.to(dtype=weight_dtype) for k, v in accelerate_state_dict.items()} + save_file(accelerate_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + 
loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + else: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + models[0].save_pretrained(os.path.join(output_dir, "transformer")) + if not args.use_deepspeed: + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = WanTransformer3DModel.from_pretrained( + input_dir, subfolder="transformer" + ) + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. 
Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + accelerator_fake_score_transformer3d.register_save_state_pre_hook(save_model_hook) + accelerator_fake_score_transformer3d.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + generator_transformer3d.enable_gradient_checkpointing() + fake_score_transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except Exception: + raise ImportError( + "Please install came_pytorch to use CAME. 
You can do so by running `pip install came_pytorch`" + ) + + optimizer_cls = CAME + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list(filter(lambda p: p.requires_grad, generator_transformer3d.parameters())) + trainable_params_optim = [ + {'params': [], 'lr': args.learning_rate}, + {'params': [], 'lr': args.learning_rate / 2}, + ] + in_already = [] + for name, param in generator_transformer3d.named_parameters(): + high_lr_flag = False + if name in in_already: + continue + for trainable_module_name in args.trainable_modules: + if trainable_module_name in name: + in_already.append(name) + high_lr_flag = True + trainable_params_optim[0]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate}") + break + if high_lr_flag: + continue + for trainable_module_name in args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + in_already.append(name) + trainable_params_optim[1]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate / 2}") + break + + fake_trainable_params = list(filter(lambda p: p.requires_grad, fake_score_transformer3d.parameters())) + fake_trainable_params_optim = [ + {'params': [], 'lr': args.learning_rate}, + {'params': [], 'lr': args.learning_rate / 2}, + ] + in_already = [] + for name, param in fake_score_transformer3d.named_parameters(): + high_lr_flag = False + if name in in_already: + continue + for trainable_module_name in args.trainable_modules: + if trainable_module_name in name: + in_already.append(name) + high_lr_flag = True + fake_trainable_params_optim[0]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate}") + break + if high_lr_flag: + continue + for trainable_module_name in args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + in_already.append(name) + fake_trainable_params_optim[1]['params'].append(param) + if 
accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate / 2}") + break + + if args.use_came: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + critic_optimizer = optimizer_cls( + fake_trainable_params_optim, + lr=args.learning_rate_critic, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + else: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + critic_optimizer = optimizer_cls( + fake_trainable_params_optim, + lr=args.learning_rate_critic, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the training dataset + sample_n_frames_bucket_interval = vae.config.temporal_compression_ratio + + if args.fix_sample_size is not None and args.enable_bucket: + args.video_sample_size = max(max(args.fix_sample_size), args.video_sample_size) + args.image_sample_size = max(max(args.fix_sample_size), args.image_sample_size) + args.training_with_video_token_length = False + args.random_hw_adapt = False + + # Get the dataset + if args.train_mode != "normal": + train_dataset = ImageVideoDataset( + args.train_data_meta, args.train_data_dir, + video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, + video_repeat=args.video_repeat, + image_sample_size=args.image_sample_size, + enable_bucket=args.enable_bucket, enable_inpaint=True if args.train_mode != "normal" else False, + ) + else: + train_dataset = TextDataset( + args.train_data_meta + ) + + def get_length_to_frame_num(token_length): + if args.image_sample_size > args.video_sample_size: + sample_sizes = list(range(args.video_sample_size, 
args.image_sample_size + 1, 128)) + + if sample_sizes[-1] != args.image_sample_size: + sample_sizes.append(args.image_sample_size) + else: + sample_sizes = [args.image_sample_size] + + length_to_frame_num = { + sample_size: min(token_length / sample_size / sample_size, args.video_sample_n_frames) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 for sample_size in sample_sizes + } + + return length_to_frame_num + + if args.enable_bucket and args.train_mode != "normal": + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = AspectRatioBatchImageVideoSampler( + sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, + batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True, + aspect_ratios=aspect_ratio_sample_size, + ) + + def collate_fn(examples): + # Get token length + target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size + length_to_frame_num = get_length_to_frame_num(target_token_length) + + # Create new output + new_examples = {} + new_examples["target_token_length"] = target_token_length + new_examples["pixel_values"] = [] + new_examples["text"] = [] + # Used in Inpaint mode + if args.train_mode != "normal": + new_examples["mask_pixel_values"] = [] + new_examples["mask"] = [] + new_examples["clip_pixel_values"] = [] + + # Get downsample ratio in image and videos + pixel_value = examples[0]["pixel_values"] + data_type = examples[0]["data_type"] + f, h, w, c = np.shape(pixel_value) + if data_type == 'image': + random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size, image_ratio=[args.image_sample_size / args.video_sample_size], rng=rng) + + aspect_ratio_sample_size = {key : [x / 512 * 
args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + else: + if args.random_hw_adapt: + if args.training_with_video_token_length: + local_min_size = np.min(np.array([np.mean(np.array([np.shape(example["pixel_values"])[1], np.shape(example["pixel_values"])[2]])) for example in examples])) + # The video will be resized to a lower resolution than its own. + choice_list = [length for length in list(length_to_frame_num.keys()) if length < local_min_size * 1.25] + if len(choice_list) == 0: + choice_list = list(length_to_frame_num.keys()) + if rng is None: + local_video_sample_size = np.random.choice(choice_list) + else: + local_video_sample_size = rng.choice(choice_list) + batch_video_length = length_to_frame_num[local_video_sample_size] + random_downsample_ratio = args.video_sample_size / local_video_sample_size + else: + random_downsample_ratio = get_random_downsample_ratio( + args.video_sample_size, rng=rng) + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + else: + random_downsample_ratio = 1 + batch_video_length = args.video_sample_n_frames + sample_n_frames_bucket_interval + + aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.video_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + if args.fix_sample_size is not None: + fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size] + elif args.random_ratio_crop: + if rng is 
None: + random_sample_size = aspect_ratio_random_crop_sample_size[ + np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + else: + random_sample_size = aspect_ratio_random_crop_sample_size[ + rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + random_sample_size = [int(x / 16) * 16 for x in random_sample_size] + else: + closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size) + closest_size = [int(x / 16) * 16 for x in closest_size] + + min_example_length = min( + [example["pixel_values"].shape[0] for example in examples] + ) + batch_video_length = int(min(batch_video_length, min_example_length)) + + # Magvae needs the number of frames to be 4n + 1. + batch_video_length = (batch_video_length - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + + if batch_video_length <= 0: + batch_video_length = 1 + + for example in examples: + if args.fix_sample_size is not None: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + # Get adapt hw for resize + fix_sample_size = list(map(lambda x: int(x), fix_sample_size)) + transform = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif args.random_ratio_crop: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
+ + # Get adapt hw for resize + b, c, h, w = pixel_values.size() + th, tw = random_sample_size + if th / tw > h / w: + nh = int(th) + nw = int(w / h * nh) + else: + nw = int(tw) + nh = int(h / w * nw) + + transform = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + else: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + + # Get adapt hw for resize + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / h > closest_size[1] / w: + resize_size = closest_size[0], int(w * closest_size[0] / h) + else: + resize_size = int(h * closest_size[1] / w), closest_size[1] + + transform = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + new_examples["pixel_values"].append(transform(pixel_values)[:batch_video_length]) + new_examples["text"].append(example["text"]) + + if args.train_mode != "normal": + mask = get_random_mask(new_examples["pixel_values"][-1].size(), image_start_only=True) + mask_pixel_values = new_examples["pixel_values"][-1] * (1 - mask) + # Wan 2.1 use 0 for masked pixels + # + torch.ones_like(new_examples["pixel_values"][-1]) * -1 * mask + new_examples["mask_pixel_values"].append(mask_pixel_values) + new_examples["mask"].append(mask) + + clip_pixel_values = new_examples["pixel_values"][-1][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + new_examples["clip_pixel_values"].append(clip_pixel_values) + + # Limit the number of frames to the same + new_examples["pixel_values"] = torch.stack([example for example in new_examples["pixel_values"]]) + if 
args.train_mode != "normal": + new_examples["mask_pixel_values"] = torch.stack([example for example in new_examples["mask_pixel_values"]]) + new_examples["mask"] = torch.stack([example for example in new_examples["mask"]]) + new_examples["clip_pixel_values"] = torch.stack([example for example in new_examples["clip_pixel_values"]]) + + # Encode prompts when enable_text_encoder_in_dataloader=True + if args.enable_text_encoder_in_dataloader: + prompt_ids = tokenizer( + new_examples['text'], + max_length=args.tokenizer_max_length, + padding="max_length", + add_special_tokens=True, + truncation=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids + prompt_attention_mask = prompt_ids.attention_mask + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=prompt_attention_mask.to("cpu"))[0] + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + new_examples['encoder_attention_mask'] = prompt_ids.attention_mask + new_examples['encoder_hidden_states'] = prompt_embeds + + neg_txt = [ + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" for text in batch['text'] + ] + neg_prompt_ids = tokenizer( + neg_txt, + max_length=args.tokenizer_max_length, + padding="max_length", + add_special_tokens=True, + truncation=True, + return_tensors="pt" + ) + neg_text_input_ids = neg_prompt_ids.input_ids + neg_prompt_attention_mask = neg_prompt_ids.attention_mask + + neg_seq_lens = neg_prompt_attention_mask.gt(0).sum(dim=1).long() + neg_prompt_embeds = text_encoder(neg_text_input_ids.to("cpu"), attention_mask=neg_prompt_attention_mask.to("cpu"))[0] + neg_prompt_embeds = [u[:v] for u, v in zip(neg_prompt_embeds, neg_seq_lens)] + + new_examples['neg_encoder_attention_mask'] = neg_prompt_ids.attention_mask + new_examples['neg_encoder_hidden_states'] = neg_prompt_embeds + + return new_examples + + # 
DataLoaders creation:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+    elif args.train_mode == "normal":
+        def collate_fn(examples):
+            """Collate text-only samples; optionally pre-encode prompts on CPU.
+
+            Returns a dict with "text" and, when
+            args.enable_text_encoder_in_dataloader is set, the positive and
+            negative attention masks plus per-sample lists of embeddings
+            trimmed to their non-padding length.
+            """
+            new_examples = {}
+            new_examples["text"] = []
+            for example in examples:
+                new_examples["text"].append(example["text"])
+
+            # Encode prompts when enable_text_encoder_in_dataloader=True
+            if args.enable_text_encoder_in_dataloader:
+                prompt_ids = tokenizer(
+                    new_examples['text'],
+                    max_length=args.tokenizer_max_length,
+                    padding="max_length",
+                    add_special_tokens=True,
+                    truncation=True,
+                    return_tensors="pt"
+                )
+                text_input_ids = prompt_ids.input_ids
+                prompt_attention_mask = prompt_ids.attention_mask
+
+                seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
+                prompt_embeds = text_encoder(text_input_ids.to("cpu"), attention_mask=prompt_attention_mask.to("cpu"))[0]
+                prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+
+                new_examples['encoder_attention_mask'] = prompt_ids.attention_mask
+                new_examples['encoder_hidden_states'] = prompt_embeds
+
+                # One negative prompt per collated sample. `batch` is not a local of
+                # this function, so iterate the texts gathered above instead.
+                neg_txt = [
+                    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" for text in new_examples['text']
+                ]
+                neg_prompt_ids = tokenizer(
+                    neg_txt,
+                    max_length=args.tokenizer_max_length,
+                    padding="max_length",
+                    add_special_tokens=True,
+                    truncation=True,
+                    return_tensors="pt"
+                )
+                neg_text_input_ids = neg_prompt_ids.input_ids
+                neg_prompt_attention_mask = neg_prompt_ids.attention_mask
+
+                neg_seq_lens = neg_prompt_attention_mask.gt(0).sum(dim=1).long()
+                neg_prompt_embeds = text_encoder(neg_text_input_ids.to("cpu"), attention_mask=neg_prompt_attention_mask.to("cpu"))[0]
+                neg_prompt_embeds = [u[:v] for u, v in zip(neg_prompt_embeds, neg_seq_lens)]
+
+                new_examples['neg_encoder_attention_mask'] = neg_prompt_ids.attention_mask
+                new_examples['neg_encoder_hidden_states'] = neg_prompt_embeds
+
+            return new_examples
+
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = BatchSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), batch_size=args.train_batch_size, drop_last=True)
+
+        # DataLoaders creation:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+    else:
+        # DataLoaders creation:
+        batch_sampler_generator = torch.Generator().manual_seed(args.seed)
+        batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size)
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            persistent_workers=True if args.dataloader_num_workers != 0 else False,
+            num_workers=args.dataloader_num_workers,
+        )
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+    )
+    # The fake-score (critic) schedule must drive the critic optimizer; binding it to
+    # `optimizer` would schedule the generator twice and leave the critic LR unscheduled.
+    fake_score_lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=critic_optimizer,
+        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+        num_training_steps=args.max_train_steps * accelerator.num_processes,
+    )
+
+    # Prepare everything with our `accelerator`.
+    generator_transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        generator_transformer3d, optimizer, train_dataloader, lr_scheduler
+    )
+    # The critic is prepared with its own accelerator instance.
+    fake_score_transformer3d, critic_optimizer, fake_score_lr_scheduler = accelerator_fake_score_transformer3d.prepare(
+        fake_score_transformer3d, critic_optimizer, fake_score_lr_scheduler
+    )
+    if fsdp_stage != 0 or zero_stage != 0:
+        from functools import partial
+
+        from videox_fun.dist import set_multi_gpus_devices, shard_model
+        # A single guard and a single shard function cover both modules that are
+        # sharded here (previously two identical guards repeated the same setup).
+        shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype)
+        real_score_transformer3d = shard_fn(real_score_transformer3d)
+        text_encoder = shard_fn(text_encoder)
+
+    # Move text_encode and vae to gpu and cast to weight_dtype
+    vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
+    real_score_transformer3d.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
+    if not args.enable_text_encoder_in_dataloader:
+        text_encoder.to(accelerator.device if not args.low_vram else "cpu")
+    if args.train_mode != "normal":
+        clip_image_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initializes automatically on the main process.
+    if accelerator.is_main_process:
+        tracker_config = dict(vars(args))
+        # List-valued arguments are dropped from the tracked configuration.
+        keys_to_pop = [k for k, v in tracker_config.items() if isinstance(v, list)]
+        for k in keys_to_pop:
+            tracker_config.pop(k)
+            print(f"Removed tracker_config['{k}']")
+        accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+    # Function for unwrapping if model was compiled with `torch.compile`.
+    def unwrap_model(model):
+        # Strip the accelerator wrapper first, then the torch.compile wrapper if present.
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
+    # Train!
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f" Num examples = {len(train_dataset)}")
+    logger.info(f" Num Epochs = {args.num_train_epochs}")
+    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f" Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            # Only the directory name is kept; it is re-joined with args.output_dir below.
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1] if len(dirs) > 0 else None
+
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+            initial_global_step = 0
+        else:
+            # The step count is parsed from the "<prefix>-<step>" directory name.
+            global_step = int(path.split("-")[1])
+
+            initial_global_step = global_step
+
+            # Prefer the epoch stored next to the checkpoint; otherwise derive it from the step count.
+            pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl")
+            if os.path.exists(pkl_path):
+                with open(pkl_path, 'rb') as file:
+                    _, first_epoch = pickle.load(file)
+            else:
+                first_epoch = global_step // num_update_steps_per_epoch
+            print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.")
+
+            accelerator.print(f"Resuming from checkpoint {path}")
+            # The critic state is loaded from the "fake_score" sub-directory of the same checkpoint.
+            fake_score_path = os.path.join(path, "fake_score")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            accelerator_fake_score_transformer3d.load_state(os.path.join(args.output_dir, fake_score_path))
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
+    if args.multi_stream and args.train_mode != "normal":
+        # create extra cuda streams to speedup inpaint vae computation
+        vae_stream_1 = torch.cuda.Stream()
+        vae_stream_2 = torch.cuda.Stream()
+    else:
+        vae_stream_1 = None
+        vae_stream_2 = None
+
+    idx_sampling = DiscreteSampling(args.train_sampling_steps, uniform_sampling=args.uniform_sampling)
+    # Scheduler timesteps picked at positions (train_sampling_steps - index) for each configured index.
+    denoising_step_list = noise_scheduler.timesteps[args.train_sampling_steps - torch.tensor(args.denoising_step_indices_list)]
+
+    for epoch in range(first_epoch, args.num_train_epochs):
+        train_dmd_loss = 0.0
+        train_denoising_loss = 0.0
+        # Re-seed the sampler per epoch so the shuffling order depends on seed + epoch.
+        batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch)
+        for step, batch in enumerate(train_dataloader):
+            # Data batch sanity check
+            if args.train_mode != "normal" and epoch == first_epoch and step == 0:
+                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
+                pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+
os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True)
+                for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
+                    pixel_value = pixel_value[None, ...]
+                    gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}'
+                    save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.mp4", rescale=True)
+
+                clip_pixel_values, mask_pixel_values, texts = batch['clip_pixel_values'].cpu(), batch['mask_pixel_values'].cpu(), batch['text']
+                mask_pixel_values = rearrange(mask_pixel_values, "b f c h w -> b c f h w")
+                for idx, (clip_pixel_value, pixel_value, text) in enumerate(zip(clip_pixel_values, mask_pixel_values, texts)):
+                    pixel_value = pixel_value[None, ...]
+                    # Derive the name from this sample's own text. Reusing the value left over
+                    # from the previous loop gave every clip/mask file the last sample's name.
+                    gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}'
+                    Image.fromarray(np.uint8(clip_pixel_value)).save(f"{args.output_dir}/sanity_check/clip_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.png")
+                    save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/mask_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.mp4", rescale=True)
+
+            with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device):
+                if args.train_mode != "normal":
+                    # Convert images to latent space
+                    pixel_values = batch["pixel_values"].to(weight_dtype)
+
+                    # Increase the batch size when the length of the latent sequence of the current sample is small
+                    # NOTE(review): the dataloader collate functions store per-sample lists under
+                    # 'encoder_hidden_states'; confirm torch.tile accepts what arrives here.
+                    if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3:
+                        if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            pixel_values = torch.tile(pixel_values, (4, 1, 1, 1, 1))
+                            if args.enable_text_encoder_in_dataloader:
+                                batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (4, 1, 1))
+                                batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (4, 1))
+                                batch['neg_encoder_hidden_states'] = torch.tile(batch['neg_encoder_hidden_states'], (4, 1, 1))
+                                batch['neg_encoder_attention_mask'] = torch.tile(batch['neg_encoder_attention_mask'], (4, 1))
+                            else:
+                                batch['text'] = batch['text'] * 4
+                        elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            pixel_values = torch.tile(pixel_values, (2, 1, 1, 1, 1))
+                            if args.enable_text_encoder_in_dataloader:
+                                batch['encoder_hidden_states'] = torch.tile(batch['encoder_hidden_states'], (2, 1, 1))
+                                batch['encoder_attention_mask'] = torch.tile(batch['encoder_attention_mask'], (2, 1))
+                                batch['neg_encoder_hidden_states'] = torch.tile(batch['neg_encoder_hidden_states'], (2, 1, 1))
+                                batch['neg_encoder_attention_mask'] = torch.tile(batch['neg_encoder_attention_mask'], (2, 1))
+                            else:
+                                batch['text'] = batch['text'] * 2
+
+                    clip_pixel_values = batch["clip_pixel_values"].to(weight_dtype)
+                    mask_pixel_values = batch["mask_pixel_values"].to(weight_dtype)
+                    mask = batch["mask"].to(weight_dtype)
+                    # Increase the batch size when the length of the latent sequence of the current sample is small
+                    # (the tests read only the frame/height/width sizes, which the tiling above left unchanged).
+                    if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3:
+                        if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            clip_pixel_values = torch.tile(clip_pixel_values, (4, 1, 1, 1))
+                            mask_pixel_values = torch.tile(mask_pixel_values, (4, 1, 1, 1, 1))
+                            mask = torch.tile(mask, (4, 1, 1, 1, 1))
+                        elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]:
+                            clip_pixel_values = torch.tile(clip_pixel_values, (2, 1, 1, 1))
+                            mask_pixel_values = torch.tile(mask_pixel_values, (2, 1, 1, 1, 1))
+                            mask = torch.tile(mask, (2, 1, 1, 1, 1))
+
+                    if args.random_frame_crop:
+                        def _create_special_list(length):
+                            # Probability list giving 0.90 to the last entry and splitting the
+                            # remaining 0.10 evenly; returns None when length is 0.
+                            if length == 1:
+                                return [1.0]
+                            if length >= 2:
+                                last_element = 0.90
+                                remaining_sum = 1.0 - last_element
+                                other_elements_value = remaining_sum / (length - 1)
+                                special_list = [other_elements_value] * (length - 1) + [last_element]
+                                return special_list
+                        select_frames = [_tmp for _tmp in list(range(sample_n_frames_bucket_interval + 1, args.video_sample_n_frames + sample_n_frames_bucket_interval, sample_n_frames_bucket_interval))]
+                        select_frames_prob = np.array(_create_special_list(len(select_frames)))
+
+                        if len(select_frames) != 0:
+                            if rng is None:
+                                temp_n_frames = np.random.choice(select_frames, p = select_frames_prob)
+                            else:
+                                temp_n_frames = rng.choice(select_frames, p = select_frames_prob)
+                        else:
+                            temp_n_frames = 1
+
+                        # Magvae needs the number of frames to be 4n + 1.
+                        temp_n_frames = (temp_n_frames - 1) // sample_n_frames_bucket_interval + 1
+
+                        pixel_values = pixel_values[:, :temp_n_frames, :, :]
+                        mask_pixel_values = mask_pixel_values[:, :temp_n_frames, :, :]
+                        mask = mask[:, :temp_n_frames, :, :]
+
+                    # Keep all node same token length to accelerate the traning when resolution grows.
+                    if args.keep_all_node_same_token_length:
+                        if args.token_sample_size > 256:
+                            numbers_list = list(range(256, args.token_sample_size + 1, 128))
+
+                            if numbers_list[-1] != args.token_sample_size:
+                                numbers_list.append(args.token_sample_size)
+                        else:
+                            numbers_list = [256]
+                        numbers_list = [_number * _number * args.video_sample_n_frames for _number in numbers_list]
+
+                        actual_token_length = index_rng.choice(numbers_list)
+                        actual_video_length = (min(
+                                actual_token_length / pixel_values.size()[-1] / pixel_values.size()[-2], args.video_sample_n_frames
+                        ) - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1
+                        actual_video_length = int(max(actual_video_length, 1))
+
+                        # Magvae needs the number of frames to be 4n + 1.
+                        actual_video_length = (actual_video_length - 1) // sample_n_frames_bucket_interval + 1
+
+                        # Truncate every per-frame tensor to the same number of frames.
+                        pixel_values = pixel_values[:, :actual_video_length, :, :]
+                        mask_pixel_values = mask_pixel_values[:, :actual_video_length, :, :]
+                        mask = mask[:, :actual_video_length, :, :]
+
+                    if args.low_vram:
+                        # Bring the encoders onto the device for this phase and park the
+                        # real-score model (and, if used here, the text encoder) on the CPU.
+                        torch.cuda.empty_cache()
+                        vae.to(accelerator.device)
+                        clip_image_encoder.to(accelerator.device)
+                        real_score_transformer3d = real_score_transformer3d.to("cpu")
+                        if not args.enable_text_encoder_in_dataloader:
+                            text_encoder.to("cpu")
+
+                    with torch.no_grad():
+                        # This way is quicker when batch grows up
+                        def _batch_encode_vae(pixel_values):
+                            # Encode in chunks of args.vae_mini_batch along the batch axis,
+                            # sample from each encoder output and concatenate the results.
+                            pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                            bs = args.vae_mini_batch
+                            new_pixel_values = []
+                            for i in range(0, pixel_values.shape[0], bs):
+                                pixel_values_bs = pixel_values[i : i + bs]
+                                pixel_values_bs = vae.encode(pixel_values_bs)[0]
+                                pixel_values_bs = pixel_values_bs.sample()
+                                new_pixel_values.append(pixel_values_bs)
+                            return torch.cat(new_pixel_values, dim = 0)
+
+                        # Encode inpaint latents.
+                        mask_latents = _batch_encode_vae(mask_pixel_values)
+                        if vae_stream_2 is not None:
+                            torch.cuda.current_stream().wait_stream(vae_stream_2)
+
+                        # Encode clean latents for teacher forcing
+                        clean_latents = None
+                        if args.use_teacher_forcing:
+                            clean_latents = _batch_encode_vae(pixel_values)
+
+                        # Repeat the first frame's mask 4x, then fold groups of 4 frames into the
+                        # second axis before resizing the inverted mask to the latent grid.
+                        mask = rearrange(mask, "b f c h w -> b c f h w")
+                        mask = torch.concat(
+                            [
+                                torch.repeat_interleave(mask[:, :, 0:1], repeats=4, dim=2),
+                                mask[:, :, 1:]
+                            ], dim=2
+                        )
+                        mask = mask.view(mask.shape[0], mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4])
+                        mask = mask.transpose(1, 2)
+                        mask = resize_mask(1 - mask, mask_latents)
+
+                        inpaint_latents = torch.concat([mask, mask_latents], dim=1)
+
+                        clip_context = []
+                        for clip_pixel_value in clip_pixel_values:
+                            clip_image = Image.fromarray(np.uint8(clip_pixel_value.float().cpu().numpy()))
+                            clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(clip_image_encoder.device, weight_dtype)
+                            _clip_context = clip_image_encoder([clip_image[:, None, :, :]])
+                            clip_context.append(_clip_context)
+                        clip_context = torch.cat(clip_context)
+
+                        target_shape = mask_latents.size()
+                else:
+                    text = batch['text']
+                    if args.fix_sample_size is not None:
+                        local_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size]
+                        num_frames = args.video_sample_n_frames
+                    else:
+                        if args.random_hw_adapt and args.training_with_video_token_length:
+                            # Get token length
+                            target_token_length = args.video_sample_n_frames * args.token_sample_size * args.token_sample_size
+                            length_to_frame_num = get_length_to_frame_num(target_token_length)
+
+                            if rng is None:
+                                local_length = np.random.choice(list(length_to_frame_num.keys()))
+                            else:
+                                local_length = rng.choice(list(length_to_frame_num.keys()))
+                            num_frames = length_to_frame_num[local_length]
+
+                            aspect_ratio_sample_size = {key : [x / 512 * local_length for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                            if rng is None:
+                                aspect_ratio_key = np.random.choice(list(aspect_ratio_sample_size.keys()))
+                            else:
+                                aspect_ratio_key = rng.choice(list(aspect_ratio_sample_size.keys()))
+                            local_sample_size = aspect_ratio_sample_size[aspect_ratio_key]
+                        else:
+                            num_frames = args.video_sample_n_frames
+
+                            aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+                            if rng is None:
+                                aspect_ratio_key = np.random.choice(list(aspect_ratio_sample_size.keys()))
+                            else:
+                                aspect_ratio_key = rng.choice(list(aspect_ratio_sample_size.keys()))
+                            local_sample_size = aspect_ratio_sample_size[aspect_ratio_key]
+                        local_sample_size = [int(x / 16) * 16 for x in local_sample_size]
+
+                    # Latent-space shape of the video to generate: (batch, channels, frames, height, width).
+                    target_shape = (
+                        len(text),
+                        vae.latent_channels,
+                        int((num_frames - 1) // vae.temporal_compression_ratio + 1),
+                        int(local_sample_size[0] // vae.spatial_compression_ratio),
+                        int(local_sample_size[1] // vae.spatial_compression_ratio),
+                    )
+
+                    # Encode clean latents for teacher forcing in T2V mode
+                    clean_latents = None
+                    if args.use_teacher_forcing:
+                        with torch.no_grad():
+                            pixel_values = batch["pixel_values"].to(weight_dtype)
+                            pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
+                            bs = args.vae_mini_batch
+                            new_pixel_values = []
+                            for i in range(0, pixel_values.shape[0], bs):
+                                pixel_values_bs = pixel_values[i : i + bs]
+                                pixel_values_bs = vae.encode(pixel_values_bs)[0]
+                                pixel_values_bs = pixel_values_bs.sample()
+                                new_pixel_values.append(pixel_values_bs)
+                            clean_latents = torch.cat(new_pixel_values, dim = 0)
+
+                if args.low_vram:
+                    vae.to('cpu')
+                    real_score_transformer3d = real_score_transformer3d.to("cpu")
+                    if args.train_mode != "normal":
+                        clip_image_encoder.to('cpu')
+                    torch.cuda.empty_cache()
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to(accelerator.device)
+
+                if args.enable_text_encoder_in_dataloader:
+                    # The collate functions hand over per-sample lists of trimmed embeddings,
+                    # which have no `.to`; move them element-wise. A plain tensor is still accepted.
+                    prompt_embeds = batch['encoder_hidden_states']
+                    if torch.is_tensor(prompt_embeds):
+                        prompt_embeds = prompt_embeds.to(device=accelerator.device)
+                    else:
+                        prompt_embeds = [u.to(device=accelerator.device) for u in prompt_embeds]
+                    neg_prompt_embeds = batch['neg_encoder_hidden_states']
+                    if torch.is_tensor(neg_prompt_embeds):
+                        neg_prompt_embeds = neg_prompt_embeds.to(device=accelerator.device)
+                    else:
+                        neg_prompt_embeds = [u.to(device=accelerator.device) for u in neg_prompt_embeds]
+                else:
+                    with torch.no_grad():
+                        prompt_ids = tokenizer(
+                            batch['text'],
+                            padding="max_length",
+                            max_length=args.tokenizer_max_length,
+                            truncation=True,
+                            add_special_tokens=True,
+                            return_tensors="pt"
+                        )
+                        text_input_ids = prompt_ids.input_ids
+                        prompt_attention_mask = prompt_ids.attention_mask
+
+                        seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
+                        prompt_embeds = text_encoder(text_input_ids.to(accelerator.device), attention_mask=prompt_attention_mask.to(accelerator.device))[0]
+                        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+
+                        neg_txt = [
+                            "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" for text in batch['text']
+                        ]
+                        neg_prompt_ids = tokenizer(
+                            neg_txt,
+                            padding="max_length",
+                            max_length=args.tokenizer_max_length,
+                            truncation=True,
+                            add_special_tokens=True,
+                            return_tensors="pt"
+                        )
+                        neg_text_input_ids = neg_prompt_ids.input_ids
+                        neg_prompt_attention_mask = neg_prompt_ids.attention_mask
+
+                        neg_seq_lens = neg_prompt_attention_mask.gt(0).sum(dim=1).long()
+                        neg_prompt_embeds = text_encoder(neg_text_input_ids.to(accelerator.device), attention_mask=neg_prompt_attention_mask.to(accelerator.device))[0]
+                        neg_prompt_embeds = [u[:v] for u, v in zip(neg_prompt_embeds, neg_seq_lens)]
+
+                if args.low_vram:
+                    generator_transformer3d = generator_transformer3d.to(accelerator.device)
+                    real_score_transformer3d = real_score_transformer3d.to(accelerator.device)
+                    fake_score_transformer3d = fake_score_transformer3d.to(accelerator.device)
+                    if not args.enable_text_encoder_in_dataloader:
+                        text_encoder.to('cpu')
+                    torch.cuda.empty_cache()
+
+                with accelerator.accumulate(generator_transformer3d):
+                    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+                        """Return the scheduler sigma nearest to each timestep.
+
+                        The result has one entry per timestep and is padded with
+                        trailing singleton axes until it has n_dim dimensions.
+                        """
+                        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+                        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+                        timesteps = timesteps.to(accelerator.device)
+
+                        # Nearest-timestep lookup for all entries at once, the same way
+                        # convert_flow_pred_to_x0 does it, instead of one `.item()` per entry.
+                        step_indices = torch.argmin(
+                            (schedule_timesteps.unsqueeze(0) - timesteps.reshape(-1, 1)).abs(), dim=1
+                        )
+                        sigma = sigmas[step_indices].flatten()
+
+                        while len(sigma.shape) < n_dim:
+                            sigma = sigma.unsqueeze(-1)
+                        return sigma
+
+                    def add_noise(latents, noise, timesteps):
+                        # Linear interpolation between latents and noise weighted by sigma.
+                        sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)
+                        return (1.0 - sigmas) * latents + sigmas * noise
+
+                    def generate_and_sync_list(num_denoising_steps, device):
+                        # Draw one step index and broadcast rank 0's draw when distributed
+                        # training is initialised; returned as a one-element list.
+                        indices = torch.randint(low=0, high=num_denoising_steps, size=(1,), generator=torch_rng, device=device)
+                        if dist.is_initialized():
+                            dist.broadcast(indices, src=0)
+                        return indices.tolist()
+
+                    def convert_flow_pred_to_x0(
+                        scheduler,
+                        flow_pred: torch.Tensor,
+                        xt: torch.Tensor,
+                        timestep: torch.Tensor
+                    ) -> torch.Tensor:
+                        """
+                        Convert flow matching's prediction to x0 prediction.
+                        Supports both 4D [B, C, H, W] and 5D [B, C, F, H, W] inputs.
+                        """
+                        original_dtype = flow_pred.dtype
+                        device = flow_pred.device
+
+                        # Work in float64 and cast back to the input dtype at the end.
+                        flow_pred = flow_pred.double()
+                        xt = xt.double()
+                        timesteps = scheduler.timesteps.to(device).double()
+                        sigmas = scheduler.sigmas.to(device).double()
+                        timestep = timestep.to(device).double()
+
+                        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
+                        sigma_t = sigmas[timestep_id]
+
+                        ndim = flow_pred.ndim
+                        if ndim == 4:
+                            sigma_t = sigma_t.view(-1, 1, 1, 1)
+                        elif ndim == 5:
+                            sigma_t = sigma_t.view(-1, 1, 1, 1, 1)
+                        else:
+                            raise ValueError(f"Expected 4D or 5D input, got {ndim}D tensor.")
+
+                        x0_pred = xt - sigma_t * flow_pred
+                        return x0_pred.to(original_dtype)
+
+                    # --- Main Training Logic ---
+                    bsz, channel, num_frames, height, width = target_shape
+                    if step % args.gen_update_interval == 0:
+                        # Self-Forcing training: create block mask for causal training
+                        patch_h, patch_w = accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:]
+
+                        # frame_seqlen: tokens per frame AFTER VAE compression and patching
+                        # VAE compresses 8x, then patches are extracted with patch_size
+                        frame_seqlen = (height * width) // (patch_h * patch_w)
+
+                        # Create block mask if not exists or parameters changed
+                        accelerator.unwrap_model(generator_transformer3d).create_block_mask_for_training(
+                            num_frames=num_frames,
+                            frame_seqlen=frame_seqlen,
+                            num_frame_per_block=args.num_frame_per_block,
+                            independent_first_frame=args.independent_first_frame,
+                            device=accelerator.device
+                        )
+
+                        if args.use_kv_cache_training:
+                            # === KV cache block-by-block training (original Self-Forcing) ===
+
+                            # Calculate frame_seq_length
+                            frame_seq_length = (target_shape[3] * target_shape[4]) // (patch_h * patch_w)
+
+                            # Determine block structure
+                            if not args.independent_first_frame:
+                                assert num_frames % args.num_frame_per_block == 0
+                                num_blocks = num_frames // args.num_frame_per_block
+                            else:
+                                assert (num_frames - 1) % args.num_frame_per_block == 0
+                                num_blocks = (num_frames - 1) //
args.num_frame_per_block
+
+                            all_num_frames = [args.num_frame_per_block] * num_blocks
+                            if args.independent_first_frame:
+                                all_num_frames = [1] + all_num_frames
+
+                            # Initialize KV cache
+                            # Read the config from the unwrapped module, as the patch-size lookup
+                            # does; a distributed wrapper does not necessarily expose `.config`.
+                            generator_config = accelerator.unwrap_model(generator_transformer3d).config
+                            num_layers = generator_config.num_layers
+                            num_heads = generator_config.num_heads
+                            head_dim = generator_config.dim // num_heads
+                            text_len = 512  # T5 sequence length
+
+                            kv_cache = initialize_kv_cache_for_training(
+                                batch_size=bsz,
+                                num_frames=num_frames,
+                                frame_seq_length=frame_seq_length,
+                                num_layers=num_layers,
+                                num_heads=num_heads,
+                                head_dim=head_dim,
+                                dtype=weight_dtype,
+                                device=accelerator.device
+                            )
+
+                            crossattn_cache = initialize_crossattn_cache_for_training(
+                                batch_size=bsz,
+                                text_len=text_len,
+                                num_layers=num_layers,
+                                num_heads=num_heads,
+                                head_dim=head_dim,
+                                dtype=weight_dtype,
+                                device=accelerator.device
+                            )
+
+                            # Block-by-block generation
+                            generator_noise = torch.randn(target_shape, device=accelerator.device, generator=torch_rng, dtype=weight_dtype)
+                            current_start_frame = 0
+                            num_input_frames = 0  # T2V mode
+
+                            # Use actual batch size from generator_noise (may differ due to SP)
+                            actual_bsz = generator_noise.shape[0]
+                            output_pred = torch.zeros_like(generator_noise)
+
+                            # Decide whether to use teacher forcing for this video (once per video, not per block)
+                            # The draw is made on accelerator.device like every other torch_rng draw
+                            # here, so the generator and the sampled tensor share a device.
+                            use_teacher_forcing_step = (
+                                args.use_teacher_forcing and
+                                torch.rand(1, generator=torch_rng, device=accelerator.device).item() < args.teacher_forcing_prob
+                            )
+
+                            # Prepare clean_x and aug_t for teacher forcing
+                            clean_x = None
+                            aug_t = None
+                            if use_teacher_forcing_step and clean_latents is not None:
+                                aug_t = torch.zeros(bsz, device=accelerator.device, dtype=torch.int64)
+
+                            for block_idx, current_num_frames in enumerate(all_num_frames):
+                                # Extract noise for current block
+                                start_idx = current_start_frame - num_input_frames
+                                end_idx = start_idx + current_num_frames
+                                noisy_input = generator_noise[:, :, start_idx:end_idx]
+
+                                # Extract clean latents for current block if using teacher forcing
+                                if use_teacher_forcing_step and clean_latents is not None:
+                                    clean_x_block = clean_latents[:, :, start_idx:end_idx]
+                                    clean_x = [clean_x_block[i] for i in range(bsz)]
+                                else:
+                                    clean_x = None
+
+                                # Denoise loop for current block
+                                num_denoising_steps = len(denoising_step_list)
+                                final_step_index = generate_and_sync_list(num_denoising_steps, device=noisy_input.device)[0]
+
+                                for index, current_timestep in enumerate(denoising_step_list):
+                                    is_final_step = (index == final_step_index)
+                                    timestep = torch.full(
+                                        [bsz, current_num_frames],
+                                        current_timestep,
+                                        device=noisy_input.device,
+                                        dtype=torch.int64
+                                    )
+
+                                    # Only the randomly chosen final step is run with gradients enabled.
+                                    context_manager = torch.no_grad() if not is_final_step else contextlib.nullcontext()
+
+                                    with context_manager:
+                                        # Convert noisy_input to list format
+                                        noisy_input_list = [noisy_input[i] for i in range(bsz)]
+
+                                        # Use full seq_len (consistent with inference code)
+                                        full_seq_len = frame_seqlen * num_frames
+
+                                        generator_pred_block = generator_transformer3d(
+                                            x=noisy_input_list,
+                                            context=prompt_embeds,
+                                            t=timestep,
+                                            seq_len=full_seq_len,
+                                            kv_cache=kv_cache,
+                                            crossattn_cache=crossattn_cache,
+                                            current_start=current_start_frame * frame_seq_length,
+                                            cache_start=None,
+                                            y=inpaint_latents if args.train_mode != "normal" else None,
+                                            clip_fea=clip_context if args.train_mode != "normal" else None,
+                                            clean_x=clean_x,
+                                            aug_t=aug_t,
+                                        )
+
+                                        # Stack list output to tensor: [B, C, F, H, W]
+                                        if isinstance(generator_pred_block, list):
+                                            generator_pred_block = torch.stack(generator_pred_block, dim=0)
+
+                                        # Flatten timestep for convert_flow_pred_to_x0: [B, F] -> [B*F]
+                                        generator_pred_block = convert_flow_pred_to_x0(
+                                            scheduler=noise_scheduler,
+                                            flow_pred=generator_pred_block,
+                                            xt=noisy_input,
+                                            timestep=timestep[:, 0]
+                                        )
+
+                                    if is_final_step:
+                                        break
+
+                                    # Add noise for next step
+                                    next_timestep = denoising_step_list[index + 1] * torch.ones(
+                                        bsz, dtype=torch.long, device=noisy_input.device
+                                    )
+                                    noisy_input = add_noise(
+                                        generator_pred_block,
+                                        torch.randn(generator_pred_block.shape, dtype=generator_pred_block.dtype, device=generator_pred_block.device, generator=torch_rng),
+                                        next_timestep
+                                    )
+
+                                # Record output
+                                output_pred[:, :, current_start_frame:current_start_frame + current_num_frames] = generator_pred_block
+
+                                # Update KV cache with context noise
+                                if block_idx < len(all_num_frames) - 1:
+                                    context_timestep = torch.ones([bsz, current_num_frames], device=accelerator.device, dtype=torch.int64) * args.context_noise
+
+                                    # Add context noise
+                                    generator_pred_block_noisy = add_noise(
+                                        generator_pred_block,
+                                        torch.randn(generator_pred_block.shape, dtype=generator_pred_block.dtype, device=generator_pred_block.device, generator=torch_rng),
+                                        context_timestep[:, 0]
+                                    )
+
+                                    generator_pred_block_noisy_list = [generator_pred_block_noisy[i] for i in range(bsz)]
+
+                                    # Use full seq_len (consistent with inference code)
+                                    full_seq_len = frame_seqlen * num_frames
+
+                                    # Forward pass whose output is discarded; it is run only to refresh the caches.
+                                    with torch.no_grad():
+                                        generator_transformer3d(
+                                            x=generator_pred_block_noisy_list,
+                                            context=prompt_embeds,
+                                            t=context_timestep,
+                                            seq_len=full_seq_len,
+                                            kv_cache=kv_cache,
+                                            crossattn_cache=crossattn_cache,
+                                            current_start=current_start_frame * frame_seq_length,
+                                            cache_start=None,
+                                            y=inpaint_latents if args.train_mode != "normal" else None,
+                                            clip_fea=clip_context if args.train_mode != "normal" else None,
+                                        )
+
+                                current_start_frame += current_num_frames
+
+                            # Final output
+                            generator_pred = output_pred
+                            seq_len = frame_seqlen * num_frames  # For fake/real score computation
+
+                        else:
+                            # === Original block mask training ===
+                            # Standard backward simulation training
+                            generator_noise = torch.randn(target_shape, device=accelerator.device, generator=torch_rng, dtype=weight_dtype)
+                            num_denoising_steps = len(denoising_step_list)
+                            final_step_index = generate_and_sync_list(num_denoising_steps, device=generator_noise.device)[0]
+
+                            # Precompute seq_len once (same for all steps)
+                            seq_len = frame_seqlen * num_frames
+
+                            #
Decide whether to use teacher forcing for this step + use_teacher_forcing_step = ( + args.use_teacher_forcing and + torch.rand(1, generator=torch_rng).item() < args.teacher_forcing_prob + ) + + # Prepare clean_x and aug_t for teacher forcing + clean_x = None + aug_t = None + if use_teacher_forcing_step and clean_latents is not None: + clean_x = [clean_latents[i] for i in range(clean_latents.size(0))] + aug_t = torch.zeros(bsz, device=accelerator.device, dtype=torch.int64) + + for index, current_timestep in enumerate(denoising_step_list): + is_final_step = (index == final_step_index) + timestep = torch.full( + generator_noise.shape[:1], + current_timestep, + device=generator_noise.device, + dtype=torch.int64 + ) + + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + context_manager = torch.no_grad() if not is_final_step else contextlib.nullcontext() + + with context_manager: + # Use block_mask for causal training (一次性处理整个视频) + generator_pred = generator_transformer3d( + x=generator_noise, + context=prompt_embeds, + t=timestep, + seq_len=seq_len, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + clean_x=clean_x, + aug_t=aug_t, + ) + generator_pred = convert_flow_pred_to_x0( + scheduler=noise_scheduler, + flow_pred=generator_pred, + xt=generator_noise, + timestep=timestep + ) + + if is_final_step: + break + + next_timestep = denoising_step_list[index + 1] * torch.ones( + generator_noise.shape[:1], dtype=torch.long, device=generator_noise.device + ) + generator_noise = add_noise( + generator_pred, + torch.randn(generator_pred.shape, dtype=generator_pred.dtype, device=generator_pred.device, generator=torch_rng), + next_timestep + ) + + # Common code for both KV cache and block mask training + indices = idx_sampling(bsz, generator=torch_rng, device=accelerator.device).long().cpu() + generator_timestep = 
noise_scheduler.timesteps[indices].to(device=accelerator.device) + generator_denoised_input = add_noise( + generator_pred, + torch.randn(generator_pred.shape, dtype=generator_pred.dtype, device=generator_pred.device, generator=torch_rng), + generator_timestep + ).detach().to(accelerator.device, dtype=weight_dtype) + + # Compute fake score + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device), torch.no_grad(): + fake_score_main_cond = fake_score_transformer3d( + x=generator_denoised_input, + context=prompt_embeds, + t=generator_timestep, + seq_len=seq_len, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + ) + fake_score_main_cond = convert_flow_pred_to_x0( + scheduler=noise_scheduler, + flow_pred=fake_score_main_cond, + xt=generator_denoised_input, + timestep=generator_timestep + ) + + if args.fake_guidance_scale != 0.0: + fake_score_main_uncond = fake_score_transformer3d( + x=generator_denoised_input, + context=neg_prompt_embeds, + t=generator_timestep, + seq_len=seq_len, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + ) + fake_score_main_uncond = convert_flow_pred_to_x0( + scheduler=noise_scheduler, + flow_pred=fake_score_main_uncond, + xt=generator_denoised_input, + timestep=generator_timestep + ) + fake_score_main = fake_score_main_uncond + ( + fake_score_main_cond - fake_score_main_uncond + ) * args.fake_guidance_scale + else: + fake_score_main = fake_score_main_cond + + # Compute real score + real_score_main_cond = real_score_transformer3d( + x=generator_denoised_input, + context=prompt_embeds, + t=generator_timestep, + seq_len=seq_len, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + ) + real_score_main_cond = convert_flow_pred_to_x0( + scheduler=noise_scheduler, + 
flow_pred=real_score_main_cond, + xt=generator_denoised_input, + timestep=generator_timestep + ) + + real_score_main_uncond = real_score_transformer3d( + x=generator_denoised_input, + context=neg_prompt_embeds, + t=generator_timestep, + seq_len=seq_len, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + ) + real_score_main_uncond = convert_flow_pred_to_x0( + scheduler=noise_scheduler, + flow_pred=real_score_main_uncond, + xt=generator_denoised_input, + timestep=generator_timestep + ) + + real_score_main = real_score_main_uncond + ( + real_score_main_cond - real_score_main_uncond + ) * args.real_guidance_scale + + # DMD loss + fake_to_real_grad = fake_score_main - real_score_main + generator_to_real_norm = generator_pred - real_score_main + normalizer = torch.abs(generator_to_real_norm).mean(dim=[1, 2, 3, 4], keepdim=True) + fake_to_real_grad = fake_to_real_grad / normalizer + fake_to_real_grad = torch.nan_to_num(fake_to_real_grad) + + dmd_loss = 0.5 * F.mse_loss( + generator_pred.double(), + (generator_pred.double() - fake_to_real_grad.double()).detach(), + reduction="mean" + ) + + avg_dmd_loss = accelerator.gather(dmd_loss.repeat(args.train_batch_size)).mean() + train_dmd_loss += avg_dmd_loss.item() / args.gradient_accumulation_steps + + if args.low_vram: + real_score_transformer3d = real_score_transformer3d.to("cpu") + fake_score_transformer3d = fake_score_transformer3d.to("cpu") + torch.cuda.empty_cache() + + accelerator.backward(dmd_loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if args.low_vram: + fake_score_transformer3d = fake_score_transformer3d.to(accelerator.device) + torch.cuda.empty_cache() + + with accelerator_fake_score_transformer3d.accumulate(fake_score_transformer3d): + # --- Fake Critic Denoising Loss --- + + if args.use_kv_cache_training: + # 
KV cache mode: block-by-block generation + fake_score_critic_noise = torch.randn(target_shape, device=accelerator.device, generator=torch_rng, dtype=weight_dtype) + + # Calculate frame_seq_length + frame_seq_length = (target_shape[3] * target_shape[4]) // (patch_h * patch_w) + + # Determine block structure + if not args.independent_first_frame: + num_blocks = num_frames // args.num_frame_per_block + else: + num_blocks = (num_frames - 1) // args.num_frame_per_block + + all_num_frames = [args.num_frame_per_block] * num_blocks + if args.independent_first_frame: + all_num_frames = [1] + all_num_frames + + # Initialize KV cache + num_layers = generator_transformer3d.config.num_layers + num_heads = generator_transformer3d.config.num_heads + head_dim = generator_transformer3d.config.dim // num_heads + text_len = 512 + + critic_kv_cache = initialize_kv_cache_for_training( + batch_size=bsz, + num_frames=num_frames, + frame_seq_length=frame_seq_length, + num_layers=num_layers, + num_heads=num_heads, + head_dim=head_dim, + dtype=weight_dtype, + device=accelerator.device + ) + + critic_crossattn_cache = initialize_crossattn_cache_for_training( + batch_size=bsz, + text_len=text_len, + num_layers=num_layers, + num_heads=num_heads, + head_dim=head_dim, + dtype=weight_dtype, + device=accelerator.device + ) + + current_start_frame = 0 + num_input_frames = 0 + output_pred = torch.zeros_like(fake_score_critic_noise) + + for block_idx, current_num_frames in enumerate(all_num_frames): + start_idx = current_start_frame - num_input_frames + end_idx = start_idx + current_num_frames + noisy_input = fake_score_critic_noise[:, :, start_idx:end_idx] + + num_denoising_steps = len(denoising_step_list) + final_step_index = generate_and_sync_list(num_denoising_steps, device=noisy_input.device)[0] + + for index, current_timestep in enumerate(denoising_step_list): + is_final_step = (index == final_step_index) + timestep = torch.full( + [bsz, current_num_frames], + current_timestep, + 
device=noisy_input.device, + dtype=torch.int64 + ) + + context_manager = torch.no_grad() + + with context_manager: + noisy_input_list = [noisy_input[i] for i in range(bsz)] + + # Use full seq_len (consistent with inference code) + full_seq_len = frame_seqlen * num_frames + + fake_score_denoised_pred_block = generator_transformer3d( + x=noisy_input_list, + context=prompt_embeds, + t=timestep, + seq_len=full_seq_len, + kv_cache=critic_kv_cache, + crossattn_cache=critic_crossattn_cache, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + ) + + # Stack list output to tensor: [B, C, F, H, W] + if isinstance(fake_score_denoised_pred_block, list): + fake_score_denoised_pred_block = torch.stack(fake_score_denoised_pred_block, dim=0) + + fake_score_denoised_pred_block = convert_flow_pred_to_x0( + scheduler=noise_scheduler, + flow_pred=fake_score_denoised_pred_block, + xt=noisy_input, + timestep=timestep[:, 0] + ) + + if is_final_step: + break + + next_timestep = denoising_step_list[index + 1] * torch.ones( + bsz, dtype=torch.long, device=noisy_input.device + ) + noisy_input = add_noise( + fake_score_denoised_pred_block, + torch.randn(fake_score_denoised_pred_block.shape, dtype=fake_score_denoised_pred_block.dtype, device=fake_score_denoised_pred_block.device, generator=torch_rng), + next_timestep + ) + + output_pred[:, :, current_start_frame:current_start_frame + current_num_frames] = fake_score_denoised_pred_block + + # Update KV cache + if block_idx < len(all_num_frames) - 1: + context_timestep = torch.ones([bsz, current_num_frames], device=accelerator.device, dtype=torch.int64) * args.context_noise + + fake_score_denoised_pred_noisy = add_noise( + fake_score_denoised_pred_block, + torch.randn(fake_score_denoised_pred_block.shape, dtype=fake_score_denoised_pred_block.dtype, device=fake_score_denoised_pred_block.device, 
generator=torch_rng), + context_timestep[:, 0] + ) + fake_score_denoised_pred_noisy_list = [fake_score_denoised_pred_noisy[i] for i in range(bsz)] + + # Use full seq_len (consistent with inference code) + full_seq_len = frame_seqlen * num_frames + + with torch.no_grad(): + generator_transformer3d( + x=fake_score_denoised_pred_noisy_list, + context=prompt_embeds, + t=context_timestep, + seq_len=full_seq_len, + kv_cache=critic_kv_cache, + crossattn_cache=critic_crossattn_cache, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + ) + + current_start_frame += current_num_frames + + fake_score_denoised_pred = output_pred + seq_len = frame_seq_length * num_frames + + else: + # Original block mask mode + with torch.no_grad(): + fake_score_critic_noise = torch.randn(target_shape, device=accelerator.device, generator=torch_rng, dtype=weight_dtype) + num_denoising_steps = len(denoising_step_list) + final_step_index = generate_and_sync_list(num_denoising_steps, device=fake_score_critic_noise.device)[0] + + patch_h, patch_w = accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:] + seq_len = math.ceil((width * height) / (patch_h * patch_w) * num_frames) + + for index, current_timestep in enumerate(denoising_step_list): + is_final_step = (index == final_step_index) + timestep = torch.full( + fake_score_critic_noise.shape[:1], + current_timestep, + device=fake_score_critic_noise.device, + dtype=torch.int64 + ) + + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + fake_score_denoised_pred = generator_transformer3d( + x=fake_score_critic_noise, + context=prompt_embeds, + t=timestep, + seq_len=seq_len, + y=inpaint_latents if args.train_mode != "normal" else None, + clip_fea=clip_context if args.train_mode != "normal" else None, + clean_x=None, + aug_t=None, + ) + 
def custom_mse_loss(noise_pred, target, weighting=None, threshold=50):
    """Mean squared error that zeroes out elements with a large residual.

    Both tensors are cast to float32. Every element whose absolute
    residual ``|noise_pred - target|`` exceeds ``threshold`` contributes
    zero to the loss, but still counts in the denominator of the mean.

    Args:
        noise_pred: Predicted tensor.
        target: Reference tensor, broadcast-compatible with ``noise_pred``.
        weighting: Optional multiplicative weights applied element-wise
            after masking.
        threshold: Largest absolute residual that is still kept.

    Returns:
        Scalar tensor holding the mean of the masked (and optionally
        weighted) squared errors.
    """
    pred_fp32, target_fp32 = noise_pred.float(), target.float()
    per_element = F.mse_loss(pred_fp32, target_fp32, reduction='none')

    # 1.0 where the residual is within the threshold, 0.0 elsewhere.
    inlier = ((pred_fp32 - target_fp32).abs() <= threshold).float()
    per_element = per_element * inlier

    if weighting is None:
        return per_element.mean()
    return (per_element * weighting).mean()
accelerator.gather(denoising_loss.repeat(args.train_batch_size)).mean() + train_denoising_loss += avg_denoising_loss.item() / args.gradient_accumulation_steps + + accelerator_fake_score_transformer3d.backward(denoising_loss) + if accelerator_fake_score_transformer3d.sync_gradients: + accelerator_fake_score_transformer3d.clip_grad_norm_(fake_trainable_params, args.max_grad_norm) + critic_optimizer.step() + fake_score_lr_scheduler.step() + critic_optimizer.zero_grad() + + if args.low_vram: + fake_score_transformer3d = fake_score_transformer3d.to(accelerator.device) + generator_transformer3d = generator_transformer3d.to(accelerator.device) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_denoising_loss": train_denoising_loss, "train_dmd_loss": train_dmd_loss}, step=global_step) + train_dmd_loss = 0.0 + train_denoising_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = 
os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + fake_score_save_path = os.path.join(save_path, "fake_score") + accelerator.save_state(save_path) + accelerator_fake_score_transformer3d.save_state(fake_score_save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + log_validation( + vae, + text_encoder, + tokenizer, + clip_image_encoder, + generator_transformer3d, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + + logs = {"denoising_loss": denoising_loss.detach().item(), "dmd_loss": dmd_loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + log_validation( + vae, + text_encoder, + tokenizer, + clip_image_encoder, + generator_transformer3d, + args, + config, + accelerator, + weight_dtype, + global_step, + ) + + # Create the pipeline using the trained modules and save it. 
# Launch Self-Forcing distillation training for Wan2.1-T2V-1.3B.
export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA.
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
# Must be exported: a bare assignment is not passed to the launched processes.
export NCCL_DEBUG=INFO

accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_distill.py \
  --config_path="config/wan2.1/wan_civitai.yaml" \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --train_data_dir="$DATASET_NAME" \
  --train_data_meta="$DATASET_META_NAME" \
  --image_sample_size=640 \
  --video_sample_size=640 \
  --token_sample_size=640 \
  --video_sample_stride=2 \
  --video_sample_n_frames=81 \
  --train_batch_size=1 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=50 \
  --learning_rate=2e-06 \
  --learning_rate_critic=2e-07 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir_wan2.1_self_forcing_distill" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --random_hw_adapt \
  --training_with_video_token_length \
  --enable_bucket \
  --uniform_sampling \
  --train_mode="normal" \
  --trainable_modules "." \
  --low_vram
+ +import glob +import json +import math +import os +import types +import warnings +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import is_torch_version, logging +from torch import nn +from torch.nn.attention.flex_attention import (BlockMask, create_block_mask, + flex_attention) + +from ..dist import (get_sequence_parallel_rank, + get_sequence_parallel_world_size, get_sp_group, + usp_attn_forward, xFuserLongContextAttention) +from ..utils import cfg_skip +from .attention_utils import attention +from .wan_camera_adapter import SimpleAdapter +from .wan_transformer3d import (MLPProj, WanLayerNorm, WanRMSNorm, + WanTransformer3DModel, rope_apply, rope_params, + sinusoidal_embedding_1d) +import torch._dynamo as dynamo + +if dynamo.config.cache_size_limit < 128: + dynamo.config.cache_size_limit = 128 + +# wan 1.3B model has a weird channel / head configurations and require max-autotune to work with flexattention +# see https://github.com/pytorch/pytorch/issues/133254 +# change to default for other models +flex_attention = torch.compile( + flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs") + + +logger = logging.get_logger(__name__) + + +@amp.autocast(enabled=False) +@torch.compiler.disable() +def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): + """ + Apply causal rotary positional embedding with frame offset support. + + This function applies RoPE with a starting frame offset, enabling causal + inference where different frames can have different positional indices. 
+ + Args: + x: Input tensor with shape (batch, seq_len, n_channels, c*2) + grid_sizes: Grid dimensions (f, h, w) for each sample + freqs: Precomputed frequency parameters + start_frame: Starting frame index for causal positioning + + Returns: + Tensor with causal RoPE applied + """ + n, c = x.size(2), x.size(3) // 2 + + # Split freqs into temporal, height, and width components + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # Process each sample in the batch + output = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # Reshape and convert to complex numbers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( + seq_len, n, -1, 2)) + # Broadcast frequencies with start_frame offset for temporal dimension + freqs_i = torch.cat([ + freqs[0][start_frame:start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # Apply rotation: x * exp(i*freq) + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + # Concatenate with padding tokens (if any) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # Append to collection + output.append(x_i) + return torch.stack(output).type_as(x) + + +class CasualWanSelfAttention(nn.Module): + """Wan self-attention mechanism with RoPE and optional windowed attention.""" + + def __init__(self, + dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6, + local_attn_size=-1, + sink_size=0): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + self.local_attn_size = local_attn_size + self.sink_size = sink_size + self.max_attention_size = 32760 if local_attn_size == -1 else local_attn_size * 1560 + + # Layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) 
def forward(
    self,
    x,
    seq_lens,
    grid_sizes,
    freqs,
    block_mask,
    kv_cache=None,
    current_start=0,
    cache_start=None,
    dtype=torch.bfloat16,
    t=0
):
    r"""
    Causal self-attention over a flattened token sequence.

    Two execution paths exist:

    * ``kv_cache is None``: the whole sequence is attended in a single
      ``flex_attention`` call restricted by ``block_mask``. If the sequence
      is exactly twice ``seq_lens[0]`` it is treated as teacher forcing and
      the same RoPE is applied to each half separately.
    * ``kv_cache`` given: the new keys/values are written into the cache
      (rolling out the oldest non-sink tokens when a local window is
      configured and the cache would overflow) and the query attends to the
      last ``self.max_attention_size`` cached tokens.

    Args:
        x(Tensor): Shape [B, L, C]
        seq_lens(Tensor): Shape [B]
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        block_mask (BlockMask): Block mask for flex attention (only used without a KV cache)
        kv_cache: Dict with entries "k", "v", "global_end_index" and
            "local_end_index"; it is updated in place
        current_start: Current starting position in token sequence
        cache_start: Accepted for interface compatibility; not used here
        dtype: Accepted for interface compatibility; not used here
        t: Accepted for interface compatibility; not used here

    Returns:
        Tensor: Shape [B, L, C], the attention output after the output projection
    """
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

    q = self.norm_q(self.q(x)).view(b, s, n, d)
    k = self.norm_k(self.k(x)).view(b, s, n, d)
    v = self.v(x).view(b, s, n, d)

    def masked_flex_attention(roped_q, roped_k):
        """Run flex attention on inputs right-padded to a multiple of 128 tokens."""
        pad = math.ceil(s / 128) * 128 - s
        value = v
        if pad > 0:
            def right_pad(tensor):
                return torch.cat([tensor, tensor.new_zeros(b, pad, n, d)], dim=1)

            roped_q, roped_k, value = right_pad(roped_q), right_pad(roped_k), right_pad(v)
        out = flex_attention(
            query=roped_q.transpose(2, 1),
            key=roped_k.transpose(2, 1),
            value=value.transpose(2, 1),
            block_mask=block_mask
        )
        # Slice by explicit length: `[:, :, :-pad]` would yield an EMPTY
        # tensor when pad == 0 (sequence already a multiple of 128).
        return out[:, :, :s].transpose(2, 1)

    if kv_cache is None:
        if s == seq_lens[0].item() * 2:
            # Teacher forcing: the two halves share the same positions, so
            # RoPE is applied to each half on its own.
            roped_query = torch.cat(
                [rope_apply(part, grid_sizes, freqs).type_as(v) for part in torch.chunk(q, 2, dim=1)],
                dim=1
            )
            roped_key = torch.cat(
                [rope_apply(part, grid_sizes, freqs).type_as(v) for part in torch.chunk(k, 2, dim=1)],
                dim=1
            )
        else:
            roped_query = rope_apply(q, grid_sizes, freqs).type_as(v)
            roped_key = rope_apply(k, grid_sizes, freqs).type_as(v)
        x = masked_flex_attention(roped_query, roped_key)
    else:
        # Causal inference with KV cache
        frame_seqlen = math.prod(grid_sizes[0][1:]).item()
        current_start_frame = current_start // frame_seqlen
        # RoPE positions are offset by the frame at which this chunk starts
        roped_query = causal_rope_apply(
            q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v)
        roped_key = causal_rope_apply(
            k, grid_sizes, freqs, start_frame=current_start_frame).type_as(v)

        num_new_tokens = roped_query.shape[1]
        current_end = current_start + num_new_tokens
        sink_tokens = self.sink_size * frame_seqlen
        kv_cache_size = kv_cache["k"].shape[1]
        prev_local_end = kv_cache["local_end_index"].item()
        prev_global_end = kv_cache["global_end_index"].item()

        local_end_index = prev_local_end + current_end - prev_global_end
        if (self.local_attn_size != -1 and current_end > prev_global_end
                and num_new_tokens + prev_local_end > kv_cache_size):
            # The cache would overflow: shift the non-sink content left to
            # discard the oldest tokens. The source slice is cloned because
            # source and destination overlap.
            num_evicted_tokens = num_new_tokens + prev_local_end - kv_cache_size
            num_rolled_tokens = prev_local_end - num_evicted_tokens - sink_tokens
            dst = slice(sink_tokens, sink_tokens + num_rolled_tokens)
            src = slice(sink_tokens + num_evicted_tokens,
                        sink_tokens + num_evicted_tokens + num_rolled_tokens)
            for name in ("k", "v"):
                kv_cache[name][:, dst] = kv_cache[name][:, src].clone()
            local_end_index -= num_evicted_tokens

        # Insert the new keys/values at the end of the valid region
        local_start_index = local_end_index - num_new_tokens
        kv_cache["k"][:, local_start_index:local_end_index] = roped_key
        kv_cache["v"][:, local_start_index:local_end_index] = v

        # Attend to at most `max_attention_size` most recent cached tokens
        window_start = max(0, local_end_index - self.max_attention_size)
        x = attention(
            roped_query,
            kv_cache["k"][:, window_start:local_end_index],
            kv_cache["v"][:, window_start:local_end_index]
        )
        kv_cache["global_end_index"].fill_(current_end)
        kv_cache["local_end_index"].fill_(local_end_index)

    # Output projection
    x = x.flatten(2)
    x = self.o(x)
    return x
class CasualWanAttentionBlock(nn.Module):
    """Transformer block: modulated causal self-attention, cross-attention, then a modulated FFN."""

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 local_attn_size=-1,
                 sink_size=0):
        super().__init__()
        # Plain configuration attributes
        self.dim, self.ffn_dim, self.num_heads = dim, ffn_dim, num_heads
        self.window_size, self.qk_norm = window_size, qk_norm
        self.cross_attn_norm, self.eps = cross_attn_norm, eps

        # Sub-modules, created in a fixed order so that parameter order and
        # random-init consumption stay the same.
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = CasualWanSelfAttention(
            dim, num_heads, window_size, qk_norm, eps, local_attn_size, sink_size)
        if cross_attn_norm:
            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm3 = nn.Identity()
        cross_attn_cls = WAN_SELF_FORCING_CROSSATTENTION_CLASSES[cross_attn_type]
        self.cross_attn = cross_attn_cls(dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )

        # Learned modulation table: six rows (shift/scale/gate for attention and FFN)
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
        kv_cache=None,
        crossattn_cache=None,
        current_start=0,
        cache_start=None,
        block_mask=None,
        dtype=torch.bfloat16,
        t=0,
    ):
        r"""
        Apply the block to a token sequence with per-frame modulation.

        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, F, 6, C]; ``e.shape[1]`` is used as the frame
                count and L is split into F groups of equal size
            seq_lens(Tensor): Shape [B], forwarded to the self-attention
            grid_sizes(Tensor): Shape [B, 3], forwarded to the self-attention
            freqs(Tensor): Rope freqs, forwarded to the self-attention
            context(Tensor): Shape [B, L_context, C]
            context_lens(Tensor): Shape [B]
            kv_cache: KV cache forwarded to the self-attention
            crossattn_cache: Cache forwarded to the cross-attention
            current_start: Forwarded to the self-attention
            cache_start: Forwarded to the self-attention
            block_mask: Forwarded to the self-attention
            dtype: Not used by this method
            t: Not used by this method

        Returns:
            Tensor: Shape [B, L, C]
        """
        num_frames = e.shape[1]
        per_frame = (num_frames, x.shape[1] // num_frames)
        (shift_attn, scale_attn, gate_attn,
         shift_ffn, scale_ffn, gate_ffn) = (self.modulation.unsqueeze(1) + e).chunk(6, dim=2)

        def modulate(normed, shift, scale):
            # Group tokens by frame so the per-frame shift/scale broadcast correctly
            return (normed.unflatten(dim=1, sizes=per_frame) * (1 + scale) + shift).flatten(1, 2)

        def gated(branch_out, gate):
            return (branch_out.unflatten(dim=1, sizes=per_frame) * gate).flatten(1, 2)

        # Self-attention branch
        attn_out = self.self_attn(
            modulate(self.norm1(x), shift_attn, scale_attn),
            seq_lens, grid_sizes,
            freqs, block_mask, kv_cache, current_start, cache_start)
        x = x + gated(attn_out, gate_attn)

        # Cross-attention branch (no modulation)
        x = x + self.cross_attn(self.norm3(x), context,
                                context_lens, crossattn_cache=crossattn_cache)

        # Feed-forward branch
        ffn_out = self.ffn(modulate(self.norm2(x), shift_ffn, scale_ffn))
        return x + gated(ffn_out, gate_ffn)
+ """ + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # Layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # Modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, F, 1, C] + """ + num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1] + + e = (self.modulation.unsqueeze(1) + e).chunk(2, dim=2) + # Apply modulation per frame and project to output + x = (self.head(self.norm(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (1 + e[1]) + e[0])) + return x + + +class WanTransformer3DModel_SelfForcing(WanTransformer3DModel): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + # _no_split_modules = ['CasualWanAttentionBlock'] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + in_channels=16, + hidden_size=2048, + add_control_adapter=False, + in_dim_control_adapter=24, + downscale_factor_control_adapter=8, + add_ref_conv=False, + in_dim_ref_conv=16, + cross_attn_type=None, + + # Self-Forcing causal inference parameters + local_attn_size=-1, + sink_size=0, + ): + r""" + Initialize the diffusion model backbone. 
+ + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to True): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + in_channels (`int`, *optional*, defaults to 16): + Alias for in_dim (diffusers compatibility) + hidden_size (`int`, *optional*, defaults to 2048): + Alias for dim (diffusers compatibility) + add_control_adapter (`bool`, *optional*, defaults to False): + Enable camera control adapter + in_dim_control_adapter (`int`, *optional*, defaults to 24): + Input channels for control adapter + downscale_factor_control_adapter (`int`, *optional*, defaults to 8): + Downscale factor for control adapter + add_ref_conv (`bool`, *optional*, defaults to False): + Enable reference frame convolution + 
in_dim_ref_conv (`int`, *optional*, defaults to 16): + Input channels for reference convolution + cross_attn_type (`str`, *optional*, defaults to None): + Cross-attention type, auto-determined from model_type if None + local_attn_size (`int`, *optional*, defaults to -1): + Local attention window size (-1 for global attention) + sink_size (`int`, *optional*, defaults to 0): + Sink token size for local attention + """ + + super().__init__( + model_type=model_type, + patch_size=patch_size, + text_len=text_len, + in_dim=in_dim, + dim=dim, + ffn_dim=ffn_dim, + freq_dim=freq_dim, + text_dim=text_dim, + out_dim=out_dim, + num_heads=num_heads, + num_layers=num_layers, + window_size=window_size, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + in_channels=in_channels, + hidden_size=hidden_size, + add_control_adapter=add_control_adapter, + in_dim_control_adapter=in_dim_control_adapter, + downscale_factor_control_adapter=downscale_factor_control_adapter, + add_ref_conv=add_ref_conv, + in_dim_ref_conv=in_dim_ref_conv, + cross_attn_type=cross_attn_type + ) + # Blocks + if cross_attn_type is None: + cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' + self.blocks = nn.ModuleList([ + CasualWanAttentionBlock( + cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, local_attn_size, sink_size + ) + for _ in range(num_layers) + ]) + for layer_idx, block in enumerate(self.blocks): + block.self_attn.layer_idx = layer_idx + block.self_attn.num_layers = self.num_layers + + # Head + self.head = CausalHead(dim, out_dim, patch_size, eps) + + # Self-forcing causal inference state + self.local_attn_size = local_attn_size + self.sink_size = sink_size + self.block_mask = None + self.num_frame_per_block = 1 + self.independent_first_frame = False + + # Other parameters + self.gradient_checkpointing = False + self.all_gather = None + self.sp_world_size = 1 + self.sp_world_rank = 0 + self.init_weights() + + def 
_set_gradient_checkpointing(self, *args, **kwargs): + if "value" in kwargs: + self.gradient_checkpointing = kwargs["value"] + if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"): + self.motioner.gradient_checkpointing = kwargs["value"] + elif "enable" in kwargs: + self.gradient_checkpointing = kwargs["enable"] + if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"): + self.motioner.gradient_checkpointing = kwargs["enable"] + else: + raise ValueError("Invalid set gradient checkpointing") + + def create_block_mask_for_training( + self, + num_frames: int, + frame_seqlen: int, + num_frame_per_block: int = 1, + independent_first_frame: bool = False, + device: torch.device | str = "cpu" + ): + """ + Create block-wise causal mask for Self-Forcing training. + + This creates a mask where each block can only attend to previous blocks, + implementing causal self-attention without KV cache (using flex attention). + + Args: + num_frames: Number of frames in the video + frame_seqlen: Sequence length per frame (H * W / patch_size^2) + num_frame_per_block: Number of frames per causal block + independent_first_frame: If True, first frame is independent [1, N, N, ...] + device: Device to create the mask on + """ + total_length = num_frames * frame_seqlen + + # Right padding to multiple of 128 for flex attention + padded_length = math.ceil(total_length / 128) * 128 - total_length + + ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + + if not independent_first_frame: + # Standard block pattern: [N, N, N, ...] + frame_indices = torch.arange( + start=0, + end=total_length, + step=frame_seqlen * num_frame_per_block, + device=device + ) + + for tmp in frame_indices: + ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block + else: + # Independent first frame pattern: [1, N, N, ...] 
+ # First frame + ends[:frame_seqlen] = frame_seqlen + # Remaining blocks + frame_indices = torch.arange( + start=frame_seqlen, + end=total_length, + step=frame_seqlen * num_frame_per_block, + device=device + ) + for tmp in frame_indices: + ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block + + def attention_mask(b, h, q_idx, kv_idx): + if self.local_attn_size == -1: + # Global block-wise causal: can attend to all previous blocks + return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) + else: + # Local attention: limited window + return ((kv_idx < ends[q_idx]) & (kv_idx >= (ends[q_idx] - self.local_attn_size * frame_seqlen))) | (q_idx == kv_idx) + + from torch.nn.attention.flex_attention import create_block_mask + + self.block_mask = create_block_mask( + attention_mask, + B=None, + H=None, + Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, + _compile=False, + device=device + ) + + # Store parameters for future reference + self.num_frame_per_block = num_frame_per_block + self.independent_first_frame = independent_first_frame + + def create_teacher_forcing_mask( + self, + device: torch.device | str, + num_frames: int, + frame_seqlen: int, + num_frame_per_block: int = 1 + ) -> BlockMask: + """ + Create block-wise teacher forcing mask for Self-Forcing training. + + This creates a mask where: + - Clean frames (first half): causal attention within clean sequence + - Noisy frames (second half): attend to all preceding clean frames + causal within noisy + + Sequence layout: [clean_frame_1, clean_frame_2, ..., noisy_frame_1, noisy_frame_2, ...] 
+ + Args: + device: Device to create the mask on + num_frames: Number of frames in the video + frame_seqlen: Sequence length per frame (H * W / patch_size^2) + num_frame_per_block: Number of frames per causal block + + Returns: + BlockMask for flex attention + """ + total_length = num_frames * frame_seqlen * 2 # Clean + noisy + + # Right padding to multiple of 128 for flex attention + padded_length = math.ceil(total_length / 128) * 128 - total_length + + clean_ends = num_frames * frame_seqlen + + # For clean context frames: [start, end] interval + context_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + + # For noisy frames: need two intervals [context_start, context_end] + [noisy_start, noisy_end] + noise_context_starts = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + noise_context_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + noise_noise_starts = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + noise_noise_ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + + # Block-wise causal mask + attention_block_size = frame_seqlen * num_frame_per_block + frame_indices = torch.arange( + start=0, + end=num_frames * frame_seqlen, + step=attention_block_size, + device=device, dtype=torch.long + ) + + # Clean frames: causal attention + for start in frame_indices: + context_ends[start:start + attention_block_size] = start + attention_block_size + + # Noisy frames: start positions + noisy_image_start_list = torch.arange( + num_frames * frame_seqlen, total_length, + step=attention_block_size, + device=device, dtype=torch.long + ) + noisy_image_end_list = noisy_image_start_list + attention_block_size + + # Noisy frames mask configuration + for block_index, (start, end) in enumerate(zip(noisy_image_start_list, noisy_image_end_list)): + # Attend to noisy tokens within the same block + noise_noise_starts[start:end] = start + 
noise_noise_ends[start:end] = end + # Attend to context tokens in previous blocks + noise_context_ends[start:end] = block_index * attention_block_size + + def attention_mask(b, h, q_idx, kv_idx): + # Clean frames mask + clean_mask = (q_idx < clean_ends) & (kv_idx < context_ends[q_idx]) + # Noisy frames mask: attend to clean + self + C1 = (kv_idx < noise_noise_ends[q_idx]) & (kv_idx >= noise_noise_starts[q_idx]) + C2 = (kv_idx < noise_context_ends[q_idx]) & (kv_idx >= noise_context_starts[q_idx]) + noise_mask = (q_idx >= clean_ends) & (C1 | C2) + + eye_mask = q_idx == kv_idx + return eye_mask | clean_mask | noise_mask + + block_mask = create_block_mask( + attention_mask, + B=None, + H=None, + Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, + _compile=False, + device=device + ) + + return block_mask + + def forward( + self, + x, + t, + context, + seq_len, + clip_fea=None, + y=None, + kv_cache: dict = None, + crossattn_cache: dict = None, + current_start: int = 0, + cache_start: int = 0, + clean_x=None, + aug_t=None, + ): + r""" + Run the diffusion model with kv caching. + See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details. + This function will be run for num_frame times. 
+ Process the latent frames one by one (1560 tokens each) + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # Params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # Embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + + # Handle teacher forcing: concatenate clean and noisy features + if clean_x is not None: + clean_x = [self.patch_embedding(u.unsqueeze(0)) for u in clean_x] + clean_x = [u.flatten(2).transpose(1, 2) for u in clean_x] + + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + logger.debug(f"[SelfForcing] seq_len: {seq_len}, seq_lens: {seq_lens}") + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_lens[0] - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # Concatenate clean features for teacher forcing + if clean_x is not None: + seq_lens_clean = torch.tensor([u.size(1) for u in clean_x], dtype=torch.long) + assert seq_lens_clean.max() <= seq_len + clean_x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_lens_clean[0] - u.size(1), u.size(2))], dim=1) for u in clean_x + ]) + x = 
torch.cat([clean_x, x], dim=1) + + # Manage block mask for training (kv_cache is None during training). + # We must recreate the mask when switching between normal training + # (causal mask) and teacher forcing (clean+noisy mask), because the + # two modes expect different sequence lengths. + if kv_cache is None: + num_frames_actual = grid_sizes[0, 0].item() + frame_seqlen_actual = grid_sizes[0, 1].item() * grid_sizes[0, 2].item() + + if clean_x is not None: + expected_mask_len = num_frames_actual * frame_seqlen_actual * 2 + is_teacher_forcing_mask = True + else: + expected_mask_len = num_frames_actual * frame_seqlen_actual + is_teacher_forcing_mask = False + + if (self.block_mask is None or + getattr(self, '_block_mask_expected_len', None) != expected_mask_len or + getattr(self, '_block_mask_is_teacher_forcing', None) != is_teacher_forcing_mask): + + if is_teacher_forcing_mask: + if self.independent_first_frame: + raise NotImplementedError("Teacher forcing with independent first frame is not supported") + self.block_mask = self.create_teacher_forcing_mask( + device=device, + num_frames=num_frames_actual, + frame_seqlen=frame_seqlen_actual, + num_frame_per_block=self.num_frame_per_block + ) + else: + self.create_block_mask_for_training( + num_frames=num_frames_actual, + frame_seqlen=frame_seqlen_actual, + num_frame_per_block=self.num_frame_per_block, + independent_first_frame=self.independent_first_frame, + device=device + ) + self._block_mask_expected_len = expected_mask_len + self._block_mask_is_teacher_forcing = is_teacher_forcing_mask + + # Time embeddings + # Ensure t is 2D [B, num_frames] to align with inference behavior. + # Training passes 1D t=[B] while inference passes 2D t=[B, F]. + # Without this, e0 shape differs and CasualWanAttentionBlock groups + # tokens by modulation dim (6) instead of actual frames (F). 
+ if t.dim() == 1: + num_frames_actual = grid_sizes[0, 0].item() + t = t.unsqueeze(1).expand(-1, num_frames_actual) + + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x)) + e0 = self.time_projection(e).unflatten( + 1, (6, self.dim)).unflatten(dim=0, sizes=t.shape) + + # Handle teacher forcing: concatenate clean and noisy time embeddings + if clean_x is not None: + if aug_t is None: + aug_t = torch.zeros_like(t) + if aug_t.dim() == 1: + aug_t = aug_t.unsqueeze(1).expand(-1, num_frames_actual) + e_clean = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, aug_t.flatten()).type_as(x)) + e0_clean = self.time_projection(e_clean).unflatten( + 1, (6, self.dim)).unflatten(dim=0, sizes=t.shape) + e0 = torch.cat([e0_clean, e0], dim=1) + + # context: text embeddings (padded to fixed length) + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # Shape: [B, 257, dim] + context = torch.concat([context_clip, context], dim=1) + + # Arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens, + block_mask=self.block_mask + ) + + def create_custom_forward(module): + def custom_forward(*inputs, **kwargs): + return module(*inputs, **kwargs) + return custom_forward + + for block_index, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + kwargs.update( + { + "kv_cache": kv_cache[block_index] if kv_cache else None, + "current_start": current_start, + "cache_start": cache_start + } + ) + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, **kwargs, + use_reentrant=False, + ) + else: + kwargs.update( + { + "kv_cache": kv_cache[block_index] if kv_cache else None, + "crossattn_cache": 
crossattn_cache[block_index] if crossattn_cache else None, + "current_start": current_start, + "cache_start": cache_start + } + ) + x = block(x, **kwargs) + + # Remove clean part for teacher forcing output + if clean_x is not None: + x = x[:, x.shape[1] // 2:] + + # Head: project to output space + x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2)) + # Unpatchify: reconstruct video from patches + x = self.unpatchify(x, grid_sizes) + return torch.stack(x) \ No newline at end of file diff --git a/videox_fun/pipeline/__init__.py b/videox_fun/pipeline/__init__.py index 6ecec533..65d97b67 100755 --- a/videox_fun/pipeline/__init__.py +++ b/videox_fun/pipeline/__init__.py @@ -31,6 +31,7 @@ from .pipeline_wan_fun_control import WanFunControlPipeline from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline from .pipeline_wan_phantom import WanFunPhantomPipeline +from .pipeline_wan_self_forcing import WanSelfForcingPipeline from .pipeline_wan_vace import WanVacePipeline from .pipeline_z_image import ZImagePipeline from .pipeline_z_image_control import ZImageControlPipeline diff --git a/videox_fun/pipeline/pipeline_wan_self_forcing.py b/videox_fun/pipeline/pipeline_wan_self_forcing.py new file mode 100644 index 00000000..821b884a --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_self_forcing.py @@ -0,0 +1,842 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, WanTransformer3DModel) +from 
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Configure ``scheduler`` via ``set_timesteps`` and return its schedule.

    Exactly one of three modes is used: a custom ``timesteps`` list, a custom
    ``sigmas`` list, or a plain ``num_inference_steps`` count. Extra keyword
    arguments are forwarded to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read the timesteps from.
        num_inference_steps (`int`):
            Number of steps; used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps
        (the schedule length when a custom schedule was supplied).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps`
            does not accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler space the steps itself.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.kv_cache_pos = None + self.kv_cache_neg = None + self.crossattn_cache_pos = None + self.crossattn_cache_neg = None + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because 
`max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # Shape: [B, C, F, H, W] (standard PyTorch format) + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other 
schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + def _initialize_kv_cache(self, batch_size, dtype, device, frame_seq_length): + """ + Initialize KV cache for causal self-attention. 
+ """ + kv_cache_pos = [] + kv_cache_neg = [] + # Use the default KV cache size (32760 tokens for global attention) + local_attn_size = getattr(self.transformer.config, 'local_attn_size', -1) + if local_attn_size != -1: + kv_cache_size = local_attn_size * frame_seq_length + else: + kv_cache_size = 32760 + + num_heads = self.transformer.config.num_heads + head_dim = self.transformer.config.dim // num_heads + + for _ in range(self.transformer.config.num_layers): + kv_cache_pos.append({ + "k": torch.zeros([batch_size, kv_cache_size, num_heads, head_dim], dtype=dtype, device=device), + "v": torch.zeros([batch_size, kv_cache_size, num_heads, head_dim], dtype=dtype, device=device), + "global_end_index": torch.tensor([0], dtype=torch.long, device=device), + "local_end_index": torch.tensor([0], dtype=torch.long, device=device) + }) + kv_cache_neg.append({ + "k": torch.zeros([batch_size, kv_cache_size, num_heads, head_dim], dtype=dtype, device=device), + "v": torch.zeros([batch_size, kv_cache_size, num_heads, head_dim], dtype=dtype, device=device), + "global_end_index": torch.tensor([0], dtype=torch.long, device=device), + "local_end_index": torch.tensor([0], dtype=torch.long, device=device) + }) + + self.kv_cache_pos = kv_cache_pos + self.kv_cache_neg = kv_cache_neg + + def _initialize_crossattn_cache(self, batch_size, dtype, device): + """ + Initialize cross-attention cache. 
+ """ + crossattn_cache_pos = [] + crossattn_cache_neg = [] + text_len = self.transformer.config.text_len + num_heads = self.transformer.config.num_heads + head_dim = self.transformer.config.dim // num_heads + + for _ in range(self.transformer.config.num_layers): + crossattn_cache_pos.append({ + "k": torch.zeros([batch_size, text_len, num_heads, head_dim], dtype=dtype, device=device), + "v": torch.zeros([batch_size, text_len, num_heads, head_dim], dtype=dtype, device=device), + "is_init": False + }) + crossattn_cache_neg.append({ + "k": torch.zeros([batch_size, text_len, num_heads, head_dim], dtype=dtype, device=device), + "v": torch.zeros([batch_size, text_len, num_heads, head_dim], dtype=dtype, device=device), + "is_init": False + }) + + self.crossattn_cache_pos = crossattn_cache_pos + self.crossattn_cache_neg = crossattn_cache_neg + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + initial_latent: Optional[torch.FloatTensor] = None, + start_frame_index: int = 0, + 
num_frame_per_block: int = 1, + independent_first_frame: bool = True, + context_noise: int = 0, + ) -> Union[WanSelfForcingPipelineOutput, Tuple]: + r""" + Function invoked when calling the pipeline for Self-Forcing causal generation. + + Args: + initial_latent: Optional initial latent frames for I2V/video extension. + Shape: (batch_size, num_input_frames, channels, height, width) + start_frame_index: Starting frame index for long video generation. + Used when continuing generation from a previous segment. + num_frame_per_block: Number of frames to generate per block. + independent_first_frame: Whether to generate the first frame independently (T2V mode). + context_noise: Context noise level for KV cache update (matches training config). + + Examples: + ```python + pass + ``` + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + self.scheduler.sigma_min = 0.0 + self.scheduler.config.shift_terminal = 0.625 + + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 5. 
Prepare latents (noise) and output buffer separately + latent_channels = self.transformer.config.in_channels + + # For I2V: num_frames_to_generate is frames to generate (not including input frames) + num_frames_to_generate = num_frames + if initial_latent is not None: + # In I2V mode, num_frames is total frames, but noise should only be the new frames + # VAE compression: num_latent_frames = (num_frames - 1) // temporal_compression + 1 + num_input_frames_temp = initial_latent.shape[2] # [B, C, F, H, W] + total_latent_frames = (num_frames - 1) // self.vae.temporal_compression_ratio + 1 + input_latent_frames = num_input_frames_temp + num_frames_to_generate = total_latent_frames - input_latent_frames + + # Prepare noise (only for frames to generate) + noise = self.prepare_latents( + batch_size, + latent_channels, + num_frames_to_generate, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + + # Calculate total output frames (input + generated) + num_input_frames = initial_latent.shape[2] if initial_latent is not None else 0 # [B, C, F, H, W] + num_output_frames = num_frames_to_generate + num_input_frames + + # Allocate output buffer: [B, C, F_total, H, W] + output = torch.zeros_like( + noise, + device=device, + dtype=weight_dtype + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Calculate sequence length and frame_seq_length + target_shape = ( + self.vae.latent_channels, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + width // self.vae.spatial_compression_ratio, + height // self.vae.spatial_compression_ratio, + ) + seq_len = math.ceil( + (target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) + * target_shape[1] + ) + + # Calculate frame_seq_length: tokens per frame + frame_seq_length = (target_shape[2] * target_shape[3]) // (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) + + # 7. 
Causal generation loop - block by block + # num_latent_frames is the number of frames after VAE compression + num_latent_frames = target_shape[1] + + # Determine num_blocks based on mode (T2V vs I2V) + # Reference: causal_inference.py line 70-78 + if not independent_first_frame or (independent_first_frame and initial_latent is not None): + # I2V mode: even with independent_first_frame, if initial_latent is provided, frames should be divisible + assert num_latent_frames % num_frame_per_block == 0, \ + f"num_latent_frames ({num_latent_frames}) must be divisible by num_frame_per_block ({num_frame_per_block})" + num_blocks = num_latent_frames // num_frame_per_block + else: + # T2V mode: no initial_latent, use [1, 4, 4, ...] pattern + assert (num_latent_frames - 1) % num_frame_per_block == 0, \ + f"num_latent_frames-1 ({num_latent_frames - 1}) must be divisible by num_frame_per_block ({num_frame_per_block})" + num_blocks = (num_latent_frames - 1) // num_frame_per_block + + # Self-Forcing causal state (reset per call) + current_start_frame = start_frame_index + cache_start_frame = 0 + + # 8. 
Initialize KV cache and cross-attention cache + # Reset caches if they exist (for multiple inference calls) + if self.kv_cache_pos is not None: + for block_index in range(len(self.kv_cache_pos)): + self.kv_cache_pos[block_index]["global_end_index"] = torch.tensor( + [0], dtype=torch.long, device=device) + self.kv_cache_pos[block_index]["local_end_index"] = torch.tensor( + [0], dtype=torch.long, device=device) + self.kv_cache_neg[block_index]["global_end_index"] = torch.tensor( + [0], dtype=torch.long, device=device) + self.kv_cache_neg[block_index]["local_end_index"] = torch.tensor( + [0], dtype=torch.long, device=device) + for block_index in range(len(self.crossattn_cache_pos)): + self.crossattn_cache_pos[block_index]["is_init"] = False + self.crossattn_cache_neg[block_index]["is_init"] = False + else: + self._initialize_kv_cache(batch_size=batch_size, dtype=weight_dtype, device=device, frame_seq_length=frame_seq_length) + self._initialize_crossattn_cache(batch_size=batch_size, dtype=weight_dtype, device=device) + + # Build all_num_frames list + # Self-Forcing: T2V with independent_first_frame uses [1, 4, 4, 4, ...] pattern + # I2V mode uses [4, 4, 4, ...] 
pattern (first frame is provided) + all_num_frames = [num_frame_per_block] * num_blocks + if independent_first_frame and initial_latent is None: + # First frame is generated independently (standard Self-Forcing T2V pattern) + all_num_frames = [1] + all_num_frames + + for block_idx, current_num_frames in enumerate(all_num_frames): + # Extract noise for current block and convert to list format + # noise only contains frames to generate, indexed from 0 + # current_start_frame tracks global position (including input frames for I2V) + # Need to offset by num_input_frames to get index in noise + start_idx = current_start_frame - num_input_frames + end_idx = start_idx + current_num_frames + noisy_input = noise[:, :, start_idx:end_idx] + + # Denoising loop for current block + # Reset scheduler state for each block (required for causal generation) + # For Euler scheduler, resetting _step_index is sufficient. + # For multi-step schedulers (UniPC, DPM++), also clear accumulated model outputs. + self.scheduler._step_index = None + if hasattr(self.scheduler, 'model_outputs'): + self.scheduler.model_outputs = [] + + for step_idx, t in enumerate(timesteps): + + # Per-frame timesteps for causal generation + timestep = torch.ones([batch_size, current_num_frames], device=device, dtype=torch.long) * t + + if do_classifier_free_guidance: + # Conditional path + with torch.cuda.amp.autocast(dtype=weight_dtype): + flow_pred_cond = self.transformer( + x=noisy_input, + context=prompt_embeds, + t=timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_pos, + crossattn_cache=self.crossattn_cache_pos, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + ) + + # Unconditional path + with torch.cuda.amp.autocast(dtype=weight_dtype): + flow_pred_uncond = self.transformer( + x=noisy_input, + context=negative_prompt_embeds, + t=timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_neg, + crossattn_cache=self.crossattn_cache_neg, + current_start=current_start_frame * 
frame_seq_length, + cache_start=None, + ) + + # CFG guidance + # Transformer output shape check + if flow_pred_cond.dim() == 5: + # Already [B, C, F, H, W] + flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_cond - flow_pred_uncond) + elif flow_pred_cond.dim() == 4: + # [F, C, H, W], need to add batch dim + flow_pred_cond = flow_pred_cond.unsqueeze(0).permute(0, 2, 1, 3, 4) + flow_pred_uncond = flow_pred_uncond.unsqueeze(0).permute(0, 2, 1, 3, 4) + flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_cond - flow_pred_uncond) + else: + raise ValueError(f"Unexpected flow_pred_cond dim: {flow_pred_cond.dim()}, shape: {flow_pred_cond.shape}") + else: + # Forward pass with KV cache + with torch.cuda.amp.autocast(dtype=weight_dtype): + flow_pred = self.transformer( + x=noisy_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_pos, + crossattn_cache=self.crossattn_cache_pos, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + ) + + # Transformer output shape check + if flow_pred.dim() == 4: + # [F, C, H, W], need to add batch dim and permute + flow_pred = flow_pred.unsqueeze(0).permute(0, 2, 1, 3, 4) + # If already 5D [B, C, F, H, W], no need to permute + + # Get current sigma for x0 conversion + sigma_t = self.scheduler.sigmas[step_idx] + + # Convert to x0: x0 = x_t - sigma_t * flow_pred (matches original wan_wrapper.py line 192) + denoised_pred = noisy_input - sigma_t * flow_pred # [B*F, C, H, W] + + if step_idx < len(timesteps) - 1: + # Not the last step: add noise for next timestep + next_t = timesteps[step_idx + 1] + + # Add noise using flow matching formula: x_{t+1} = (1-sigma_{t+1}) * x0 + sigma_{t+1} * noise + next_sigma = self.scheduler.sigmas[step_idx + 1] + local_noise = torch.randn(denoised_pred.shape, device=denoised_pred.device, dtype=denoised_pred.dtype, generator=generator) + noisy_input = (1 - next_sigma) * denoised_pred + next_sigma * local_noise + else: + noisy_input = 
denoised_pred + + # Update output with denoised block + output[:, :, cache_start_frame:cache_start_frame + current_num_frames] = denoised_pred + + # Update KV cache with clean context (timestep=context_noise) for next block + # Reference: causal_inference.py line 227 - uses context_noise for KV cache update + if block_idx < len(all_num_frames) - 1: + context_timestep = torch.ones([batch_size, current_num_frames], device=device, dtype=torch.long) * context_noise + + if do_classifier_free_guidance: + # Update both positive and negative caches + with torch.cuda.amp.autocast(dtype=weight_dtype): + self.transformer( + x=denoised_pred, + context=prompt_embeds, + t=context_timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_pos, + crossattn_cache=self.crossattn_cache_pos, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + ) + self.transformer( + x=denoised_pred, + context=negative_prompt_embeds, + t=context_timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_neg, + crossattn_cache=self.crossattn_cache_neg, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + ) + else: + with torch.cuda.amp.autocast(dtype=weight_dtype): + self.transformer( + x=denoised_pred, + context=in_prompt_embeds, + t=context_timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_pos, + crossattn_cache=self.crossattn_cache_pos, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + ) + + current_start_frame += current_num_frames + cache_start_frame += current_num_frames + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, block_idx, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + + if comfyui_progressbar: + pbar.update(1) + + # 9. 
Decode output + + if output_type == "pil": + video = self.decode_latents(output) + video = torch.from_numpy(video) + else: + video = output + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanSelfForcingPipelineOutput(videos=video) From f1df693076eab04f07dfac177d546e7998ec9c82 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 6 May 2026 14:48:31 +0800 Subject: [PATCH 02/16] Update Self-Forcing sh --- scripts/wan2.1_self_forcing/train_distill.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/wan2.1_self_forcing/train_distill.sh b/scripts/wan2.1_self_forcing/train_distill.sh index 499a8dfb..dfdb5290 100644 --- a/scripts/wan2.1_self_forcing/train_distill.sh +++ b/scripts/wan2.1_self_forcing/train_distill.sh @@ -14,6 +14,7 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --image_sample_size=640 \ --video_sample_size=640 \ --token_sample_size=640 \ + --fix_sample_size 480 832 \ --video_sample_stride=2 \ --video_sample_n_frames=81 \ --train_batch_size=1 \ From ddf64409c34ae6fc39d437182f6b0dd695f91f98 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 6 May 2026 14:50:11 +0800 Subject: [PATCH 03/16] Update Modified from in Self-Forcing model. --- videox_fun/models/wan_transformer3d_self_forcing.py | 2 +- videox_fun/pipeline/pipeline_wan_self_forcing.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/videox_fun/models/wan_transformer3d_self_forcing.py b/videox_fun/models/wan_transformer3d_self_forcing.py index 63770f6c..e0e1de38 100644 --- a/videox_fun/models/wan_transformer3d_self_forcing.py +++ b/videox_fun/models/wan_transformer3d_self_forcing.py @@ -1,6 +1,6 @@ +# Modified from https://github.com/guandeh17/Self-Forcing/blob/main/wan/modules/causal_model.py # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py # Copyright 2024-2025 The Alibaba Wan Team Authors. 
All rights reserved. - import glob import json import math diff --git a/videox_fun/pipeline/pipeline_wan_self_forcing.py b/videox_fun/pipeline/pipeline_wan_self_forcing.py index 821b884a..174cc9db 100644 --- a/videox_fun/pipeline/pipeline_wan_self_forcing.py +++ b/videox_fun/pipeline/pipeline_wan_self_forcing.py @@ -1,3 +1,4 @@ +# Modified from https://github.com/guandeh17/Self-Forcing/blob/main/pipeline/causal_diffusion_inference.py import inspect import math from dataclasses import dataclass From 31887579404659867680971d7b8c29eaa478d494 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 6 May 2026 14:55:13 +0800 Subject: [PATCH 04/16] Update Self-Forcing inference --- examples/wan2.1_self_forcing/predict_t2v.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/wan2.1_self_forcing/predict_t2v.py b/examples/wan2.1_self_forcing/predict_t2v.py index 483a5dbe..b8d48cb0 100644 --- a/examples/wan2.1_self_forcing/predict_t2v.py +++ b/examples/wan2.1_self_forcing/predict_t2v.py @@ -241,9 +241,4 @@ def save_results(): video_path = os.path.join(save_path, prefix + ".mp4") save_videos_grid(sample, video_path, fps=fps) -if ulysses_degree * ring_degree > 1: - import torch.distributed as dist - if dist.get_rank() == 0: - save_results() -else: - save_results() \ No newline at end of file +save_results() \ No newline at end of file From 78d118d0a7941531e2e172e681fba9d66fa3145e Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 08:54:24 +0800 Subject: [PATCH 05/16] Update subject_image_path --- videox_fun/data/dataset_image_video.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/videox_fun/data/dataset_image_video.py b/videox_fun/data/dataset_image_video.py index 26f75772..674d7e46 100755 --- a/videox_fun/data/dataset_image_video.py +++ b/videox_fun/data/dataset_image_video.py @@ -494,7 +494,8 @@ def get_batch(self, idx): shuffle(subject_id) subject_images = [] for i 
in range(min(len(subject_id), 4)): - subject_image = Image.open(subject_id[i]) + subject_image_path = subject_id[i] if self.data_root is None else os.path.join(self.data_root, subject_id[i]) + subject_image = Image.open(subject_image_path) if self.padding_subject_info: img = padding_image(subject_image, visual_width, visual_height) @@ -547,7 +548,8 @@ def get_batch(self, idx): shuffle(subject_id) subject_images = [] for i in range(min(len(subject_id), 4)): - subject_image = Image.open(subject_id[i]).convert('RGB') + subject_image_path = subject_id[i] if self.data_root is None else os.path.join(self.data_root, subject_id[i]) + subject_image = Image.open(subject_image_path).convert('RGB') if self.padding_subject_info: img = padding_image(subject_image, visual_width, visual_height) From c2803b37458c59d8936afd90eb56bfaacc613193 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 12:05:21 +0800 Subject: [PATCH 06/16] Update multi gpus in self-forcing --- examples/wan2.1_self_forcing/predict_t2v.py | 30 +++- videox_fun/dist/__init__.py | 9 +- videox_fun/dist/wan_xfuser.py | 154 ++++++++++++++++++ .../models/wan_transformer3d_self_forcing.py | 96 ++++++----- .../pipeline/pipeline_wan_self_forcing.py | 4 +- 5 files changed, 247 insertions(+), 46 deletions(-) diff --git a/examples/wan2.1_self_forcing/predict_t2v.py b/examples/wan2.1_self_forcing/predict_t2v.py index b8d48cb0..9208b9c0 100644 --- a/examples/wan2.1_self_forcing/predict_t2v.py +++ b/examples/wan2.1_self_forcing/predict_t2v.py @@ -45,6 +45,15 @@ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, # resulting in slower speeds but saving a large amount of GPU memory. GPU_memory_mode = "sequential_cpu_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. 
+# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = True # Compile will give a speedup in fixed resolution and need a little GPU memory. # The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload. compile_dit = False @@ -90,7 +99,7 @@ lora_weight = 0.55 save_path = "samples/wan-videos-self-forcing-t2v" -device = set_multi_gpus_devices(1, 1) +device = set_multi_gpus_devices(ulysses_degree, ring_degree) config = OmegaConf.load(config_path) # Load transformer with causal inference support if enabled @@ -171,6 +180,18 @@ scheduler=scheduler, ) +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype) + pipeline.text_encoder = shard_fn(pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + if compile_dit: for i in range(len(pipeline.transformer.blocks)): pipeline.transformer.blocks[i] = torch.compile(pipeline.transformer.blocks[i]) @@ -241,4 +262,9 @@ def save_results(): video_path = os.path.join(save_path, prefix + ".mp4") save_videos_grid(sample, video_path, fps=fps) -save_results() \ No newline at end of file +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/videox_fun/dist/__init__.py b/videox_fun/dist/__init__.py index f6f1c583..31b9c976 100755 --- a/videox_fun/dist/__init__.py +++ b/videox_fun/dist/__init__.py @@ -8,9 +8,9 @@ from .fuser import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group, 
get_world_group, init_distributed_environment, - initialize_model_parallel, sequence_parallel_all_gather, - sequence_parallel_chunk, set_multi_gpus_devices, - xFuserLongContextAttention) + initialize_model_parallel, model_parallel_is_initialized, + sequence_parallel_all_gather, sequence_parallel_chunk, + set_multi_gpus_devices, xFuserLongContextAttention) from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0 from .infinitalk_xfuser import usp_attn_infinitetalk_forward from .longcatvideo_xfuser import (usp_attn_longcatvideo_avatar_forward, @@ -20,7 +20,8 @@ from .ltx2_xfuser import (LTX2MultiGPUsAttnProcessor, LTX2PerturbedMultiGPUsAttnProcessor) from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0 -from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward +from .wan_xfuser import (usp_attn_forward, usp_attn_s2v_forward, + usp_attn_self_forcing_forward) from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor # The pai_fuser is an internally developed acceleration package, which can be used on PAI. diff --git a/videox_fun/dist/wan_xfuser.py b/videox_fun/dist/wan_xfuser.py index d54adc5c..6c2358c3 100755 --- a/videox_fun/dist/wan_xfuser.py +++ b/videox_fun/dist/wan_xfuser.py @@ -1,3 +1,5 @@ +import math + import torch import torch.cuda.amp as amp @@ -61,6 +63,55 @@ def rope_apply(x, grid_sizes, freqs): output.append(x_i) return torch.stack(output).to(dtype) +@amp.autocast(enabled=False) +@torch.compiler.disable() +def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): + """ + Apply causal rotary positional embedding with frame offset support. + + This function applies RoPE with a starting frame offset, enabling causal + inference where different frames can have different positional indices. 
+ + Args: + x: Input tensor with shape (batch, seq_len, n_channels, c*2) + grid_sizes: Grid dimensions (f, h, w) for each sample + freqs: Precomputed frequency parameters + start_frame: Starting frame index for causal positioning + + Returns: + Tensor with causal RoPE applied + """ + n, c = x.size(2), x.size(3) // 2 + + # Split freqs into temporal, height, and width components + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # Process each sample in the batch + output = [] + + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # Reshape and convert to complex numbers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( + seq_len, n, -1, 2)) + # Broadcast frequencies with start_frame offset for temporal dimension + freqs_i = torch.cat([ + freqs[0][start_frame:start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], + dim=-1).reshape(seq_len, 1, -1) + + # Apply rotation: x * exp(i*freq) + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + # Concatenate with padding tokens (if any) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # Append to collection + output.append(x_i) + return torch.stack(output).type_as(x) + def rope_apply_qk(q, k, grid_sizes, freqs): q = rope_apply(q, grid_sizes, freqs) k = rope_apply(k, grid_sizes, freqs) @@ -178,4 +229,107 @@ def qkv_fn(x): # output x = x.flatten(2) x = self.o(x) + return x + +def usp_attn_self_forcing_forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + block_mask, + kv_cache=None, + current_start=0, + cache_start=None, + dtype=torch.bfloat16, + t=0 +): + """ + USP attention forward for Self-Forcing with KV cache support. + Combines sequence parallelism with causal KV cache inference. 
+ """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + half_dtypes = (torch.float16, torch.bfloat16) + sp_size = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + if cache_start is None: + cache_start = current_start + + # QKV computation + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + # Inference mode with KV cache + frame_seqlen = math.prod(grid_sizes[0][1:]).item() + current_start_frame = current_start // frame_seqlen + + # Step 1: all_gather QKV to restore full sequence + q_full = get_sp_group().all_gather(q, dim=1) # [B, L_full, H, D] + k_full = get_sp_group().all_gather(k, dim=1) + v_full = get_sp_group().all_gather(v, dim=1) + + # Step 2: apply causal RoPE on full sequence with frame offset + roped_query_full = causal_rope_apply(q_full, grid_sizes, freqs, + start_frame=current_start_frame).type_as(v_full) + roped_key_full = causal_rope_apply(k_full, grid_sizes, freqs, + start_frame=current_start_frame).type_as(v_full) + + current_end = current_start + roped_query_full.shape[1] + sink_tokens = self.sink_size * frame_seqlen + kv_cache_size = kv_cache["k"].shape[1] + num_new_tokens = roped_query_full.shape[1] + + # Step 3: KV cache update logic with full keys + if self.local_attn_size != -1 and (current_end > kv_cache["global_end_index"].item()) and \ + (num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size): + num_evicted_tokens = num_new_tokens + kv_cache["local_end_index"].item() - kv_cache_size + num_rolled_tokens = kv_cache["local_end_index"].item() - num_evicted_tokens - sink_tokens + kv_cache["k"][:, sink_tokens:sink_tokens + num_rolled_tokens] = \ + kv_cache["k"][:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone() + kv_cache["v"][:, 
sink_tokens:sink_tokens + num_rolled_tokens] = \ + kv_cache["v"][:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone() + local_end_index = kv_cache["local_end_index"].item() + current_end - \ + kv_cache["global_end_index"].item() - num_evicted_tokens + local_start_index = local_end_index - num_new_tokens + kv_cache["k"][:, local_start_index:local_end_index] = roped_key_full + kv_cache["v"][:, local_start_index:local_end_index] = v_full + else: + local_end_index = kv_cache["local_end_index"].item() + current_end - kv_cache["global_end_index"].item() + local_start_index = local_end_index - num_new_tokens + kv_cache["k"][:, local_start_index:local_end_index] = roped_key_full + kv_cache["v"][:, local_start_index:local_end_index] = v_full + + # Step 4: chunk back to SP distribution for attention computation + roped_query = torch.chunk(roped_query_full, sp_size, dim=1)[sp_rank] + + # Step 5: compute attention using xFuserLongContextAttention for sequence parallelism + # Chunk KV cache window to match SP distribution + kv_k_full = kv_cache["k"][:, max(0, local_end_index - self.max_attention_size):local_end_index] + kv_v_full = kv_cache["v"][:, max(0, local_end_index - self.max_attention_size):local_end_index] + kv_k = torch.chunk(kv_k_full, sp_size, dim=1)[sp_rank] + kv_v = torch.chunk(kv_v_full, sp_size, dim=1)[sp_rank] + + x = xFuserLongContextAttention()( + None, + query=half(roped_query), + key=half(kv_k), + value=kv_v, + window_size=self.window_size + ) + + kv_cache["global_end_index"].fill_(current_end) + kv_cache["local_end_index"].fill_(local_end_index) + + # Output projection + x = x.flatten(2) + x = self.o(x) return x \ No newline at end of file diff --git a/videox_fun/models/wan_transformer3d_self_forcing.py b/videox_fun/models/wan_transformer3d_self_forcing.py index e0e1de38..6de2e3d2 100644 --- a/videox_fun/models/wan_transformer3d_self_forcing.py +++ b/videox_fun/models/wan_transformer3d_self_forcing.py @@ -11,6 
+11,7 @@ import numpy as np import torch +import torch._dynamo as dynamo import torch.cuda.amp as amp import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config @@ -23,14 +24,11 @@ from ..dist import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group, - usp_attn_forward, xFuserLongContextAttention) -from ..utils import cfg_skip + usp_attn_self_forcing_forward, xFuserLongContextAttention) from .attention_utils import attention -from .wan_camera_adapter import SimpleAdapter from .wan_transformer3d import (MLPProj, WanLayerNorm, WanRMSNorm, WanTransformer3DModel, rope_apply, rope_params, sinusoidal_embedding_1d) -import torch._dynamo as dynamo if dynamo.config.cache_size_limit < 128: dynamo.config.cache_size_limit = 128 @@ -42,9 +40,6 @@ flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs") -logger = logging.get_logger(__name__) - - @amp.autocast(enabled=False) @torch.compiler.disable() def causal_rope_apply(x, grid_sizes, freqs, start_frame=0): @@ -203,13 +198,20 @@ def qkv_fn(x): ) # Apply flex attention with block mask - x = flex_attention( - query=padded_roped_query.transpose(2, 1), - key=padded_roped_key.transpose(2, 1), - value=padded_v.transpose(2, 1), - block_mask=block_mask - )[:, :, :-padded_length].transpose(2, 1) - + if padded_length != 0: + x = flex_attention( + query=padded_roped_query.transpose(2, 1), + key=padded_roped_key.transpose(2, 1), + value=padded_v.transpose(2, 1), + block_mask=block_mask + )[:, :, :-padded_length].transpose(2, 1) + else: + x = flex_attention( + query=padded_roped_query.transpose(2, 1), + key=padded_roped_key.transpose(2, 1), + value=padded_v.transpose(2, 1), + block_mask=block_mask + ).transpose(2, 1) else: # Standard inference without teacher forcing roped_query = rope_apply(q, grid_sizes, freqs).type_as(v) @@ -282,6 +284,7 @@ def qkv_fn(x): local_start_index = local_end_index - num_new_tokens kv_cache["k"][:, local_start_index:local_end_index] = 
roped_key kv_cache["v"][:, local_start_index:local_end_index] = v + # Compute attention with local window x = attention( roped_query, @@ -708,6 +711,17 @@ def __init__( self.sp_world_rank = 0 self.init_weights() + def enable_multi_gpus_inference(self): + """Enable multi-GPU inference with sequence parallelism for KV cache mode.""" + self.sp_world_size = get_sequence_parallel_world_size() + self.sp_world_rank = get_sequence_parallel_rank() + self.all_gather = get_sp_group().all_gather + + # Replace self_attn forward method with USP version + for block in self.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_self_forcing_forward, block.self_attn) + def _set_gradient_checkpointing(self, *args, **kwargs): if "value" in kwargs: self.gradient_checkpointing = kwargs["value"] @@ -789,7 +803,7 @@ def attention_mask(b, h, q_idx, kv_idx): H=None, Q_LEN=total_length + padded_length, KV_LEN=total_length + padded_length, - _compile=False, + _compile=True, device=device ) @@ -877,18 +891,19 @@ def attention_mask(b, h, q_idx, kv_idx): eye_mask = q_idx == kv_idx return eye_mask | clean_mask | noise_mask - - block_mask = create_block_mask( + + self.block_mask = create_block_mask( attention_mask, B=None, H=None, Q_LEN=total_length + padded_length, KV_LEN=total_length + padded_length, - _compile=False, + _compile=True, device=device ) - - return block_mask + + # Store parameters for future reference + self.num_frame_per_block = num_frame_per_block def forward( self, @@ -945,33 +960,28 @@ def forward( grid_sizes = torch.stack( [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) x = [u.flatten(2).transpose(1, 2) for u in x] - - # Handle teacher forcing: concatenate clean and noisy features - if clean_x is not None: - clean_x = [self.patch_embedding(u.unsqueeze(0)) for u in clean_x] - clean_x = [u.flatten(2).transpose(1, 2) for u in clean_x] - seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) - logger.debug(f"[SelfForcing] seq_len: {seq_len}, 
seq_lens: {seq_lens}") + # Padding for multi-gpu inference + if self.sp_world_size > 1: + seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size assert seq_lens.max() <= seq_len - x = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_lens[0] - u.size(1), u.size(2))], - dim=1) for u in x - ]) - + x = torch.cat(x) + # Concatenate clean features for teacher forcing if clean_x is not None: + # Handle teacher forcing: concatenate clean and noisy features + clean_x = [self.patch_embedding(u.unsqueeze(0)) for u in clean_x] + clean_x = [u.flatten(2).transpose(1, 2) for u in clean_x] seq_lens_clean = torch.tensor([u.size(1) for u in clean_x], dtype=torch.long) assert seq_lens_clean.max() <= seq_len - clean_x = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_lens_clean[0] - u.size(1), u.size(2))], dim=1) for u in clean_x - ]) + clean_x = torch.cat(clean_x) x = torch.cat([clean_x, x], dim=1) # Manage block mask for training (kv_cache is None during training). # We must recreate the mask when switching between normal training # (causal mask) and teacher forcing (clean+noisy mask), because the # two modes expect different sequence lengths. 
+ num_frames_actual = None if kv_cache is None: num_frames_actual = grid_sizes[0, 0].item() frame_seqlen_actual = grid_sizes[0, 1].item() * grid_sizes[0, 2].item() @@ -990,11 +1000,11 @@ def forward( if is_teacher_forcing_mask: if self.independent_first_frame: raise NotImplementedError("Teacher forcing with independent first frame is not supported") - self.block_mask = self.create_teacher_forcing_mask( - device=device, + self.create_teacher_forcing_mask( num_frames=num_frames_actual, frame_seqlen=frame_seqlen_actual, - num_frame_per_block=self.num_frame_per_block + num_frame_per_block=self.num_frame_per_block, + device=device, ) else: self.create_block_mask_for_training( @@ -1046,6 +1056,12 @@ def forward( context_clip = self.img_emb(clip_fea) # Shape: [B, 257, dim] context = torch.concat([context_clip, context], dim=1) + # Context Parallel: split input across GPUs + if self.sp_world_size > 1: + x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank] + if t.dim() != 1: + e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank] + # Arguments kwargs = dict( e=e0, @@ -1091,6 +1107,10 @@ def custom_forward(*inputs, **kwargs): if clean_x is not None: x = x[:, x.shape[1] // 2:] + # Context Parallel: gather results from all GPUs + if self.sp_world_size > 1: + x = self.all_gather(x, dim=1) + # Head: project to output space x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2)) # Unpatchify: reconstruct video from patches diff --git a/videox_fun/pipeline/pipeline_wan_self_forcing.py b/videox_fun/pipeline/pipeline_wan_self_forcing.py index 174cc9db..43688ee1 100644 --- a/videox_fun/pipeline/pipeline_wan_self_forcing.py +++ b/videox_fun/pipeline/pipeline_wan_self_forcing.py @@ -14,7 +14,7 @@ from diffusers.video_processor import VideoProcessor from ..models import (AutoencoderKLWan, AutoTokenizer, - WanT5EncoderModel, WanTransformer3DModel) + WanT5EncoderModel, WanTransformer3DModel_SelfForcing) from ..utils.fm_solvers import 
(FlowDPMSolverMultistepScheduler, get_sampling_sigmas) from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -127,7 +127,7 @@ def __init__( tokenizer: AutoTokenizer, text_encoder: WanT5EncoderModel, vae: AutoencoderKLWan, - transformer: WanTransformer3DModel, + transformer: WanTransformer3DModel_SelfForcing, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() From 934469e15f396a02c9e2d1eb2e2efede2dec47c2 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 12:12:23 +0800 Subject: [PATCH 07/16] Update self-forcing training --- scripts/wan2.1_self_forcing/train_distill.py | 301 +++++++++++-------- 1 file changed, 171 insertions(+), 130 deletions(-) diff --git a/scripts/wan2.1_self_forcing/train_distill.py b/scripts/wan2.1_self_forcing/train_distill.py index ecf065e2..e219cf4c 100644 --- a/scripts/wan2.1_self_forcing/train_distill.py +++ b/scripts/wan2.1_self_forcing/train_distill.py @@ -1214,7 +1214,7 @@ def load_model_hook(models, input_dir): args.random_hw_adapt = False # Get the dataset - if args.train_mode != "normal": + if args.train_mode != "normal" or args.use_teacher_forcing: train_dataset = ImageVideoDataset( args.train_data_meta, args.train_data_dir, video_sample_size=args.video_sample_size, video_sample_stride=args.video_sample_stride, video_sample_n_frames=args.video_sample_n_frames, @@ -1242,7 +1242,7 @@ def get_length_to_frame_num(token_length): return length_to_frame_num - if args.enable_bucket and args.train_mode != "normal": + if args.enable_bucket: aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} batch_sampler_generator = torch.Generator().manual_seed(args.seed) batch_sampler = AspectRatioBatchImageVideoSampler( @@ -1689,7 +1689,7 @@ def unwrap_model(model): save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/mask_{gif_name[:10] if not text == '' else f'{global_step}-{idx}'}.mp4", 
rescale=True) with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): - if args.train_mode != "normal": + if args.train_mode != "normal" or args.use_teacher_forcing: # Convert images to latent space pixel_values = batch["pixel_values"].to(weight_dtype) @@ -1713,20 +1713,20 @@ def unwrap_model(model): batch['neg_encoder_attention_mask'] = torch.tile(batch['neg_encoder_attention_mask'], (2, 1)) else: batch['text'] = batch['text'] * 2 - - clip_pixel_values = batch["clip_pixel_values"].to(weight_dtype) - mask_pixel_values = batch["mask_pixel_values"].to(weight_dtype) - mask = batch["mask"].to(weight_dtype) - # Increase the batch size when the length of the latent sequence of the current sample is small - if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: - if args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: - clip_pixel_values = torch.tile(clip_pixel_values, (4, 1, 1, 1)) - mask_pixel_values = torch.tile(mask_pixel_values, (4, 1, 1, 1, 1)) - mask = torch.tile(mask, (4, 1, 1, 1, 1)) - elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: - clip_pixel_values = torch.tile(clip_pixel_values, (2, 1, 1, 1)) - mask_pixel_values = torch.tile(mask_pixel_values, (2, 1, 1, 1, 1)) - mask = torch.tile(mask, (2, 1, 1, 1, 1)) + if args.train_mode != "normal": + clip_pixel_values = batch["clip_pixel_values"].to(weight_dtype) + mask_pixel_values = batch["mask_pixel_values"].to(weight_dtype) + mask = batch["mask"].to(weight_dtype) + # Increase the batch size when the length of the latent sequence of the current sample is small + if args.auto_tile_batch_size and args.training_with_video_token_length and zero_stage != 3: + if args.video_sample_n_frames * args.token_sample_size * 
args.token_sample_size // 16 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + clip_pixel_values = torch.tile(clip_pixel_values, (4, 1, 1, 1)) + mask_pixel_values = torch.tile(mask_pixel_values, (4, 1, 1, 1, 1)) + mask = torch.tile(mask, (4, 1, 1, 1, 1)) + elif args.video_sample_n_frames * args.token_sample_size * args.token_sample_size // 4 >= pixel_values.size()[1] * pixel_values.size()[3] * pixel_values.size()[4]: + clip_pixel_values = torch.tile(clip_pixel_values, (2, 1, 1, 1)) + mask_pixel_values = torch.tile(mask_pixel_values, (2, 1, 1, 1, 1)) + mask = torch.tile(mask, (2, 1, 1, 1, 1)) if args.random_frame_crop: def _create_special_list(length): @@ -1783,7 +1783,8 @@ def _create_special_list(length): if args.low_vram: torch.cuda.empty_cache() vae.to(accelerator.device) - clip_image_encoder.to(accelerator.device) + if args.train_mode != "normal": + clip_image_encoder.to(accelerator.device) real_score_transformer3d = real_score_transformer3d.to("cpu") if not args.enable_text_encoder_in_dataloader: text_encoder.to("cpu") @@ -1800,39 +1801,47 @@ def _batch_encode_vae(pixel_values): pixel_values_bs = pixel_values_bs.sample() new_pixel_values.append(pixel_values_bs) return torch.cat(new_pixel_values, dim = 0) - - # Encode inpaint latents. 
- mask_latents = _batch_encode_vae(mask_pixel_values) - if vae_stream_2 is not None: - torch.cuda.current_stream().wait_stream(vae_stream_2) - - # Encode clean latents for teacher forcing - clean_latents = None if args.use_teacher_forcing: - clean_latents = _batch_encode_vae(pixel_values) - - mask = rearrange(mask, "b f c h w -> b c f h w") - mask = torch.concat( - [ - torch.repeat_interleave(mask[:, :, 0:1], repeats=4, dim=2), - mask[:, :, 1:] - ], dim=2 - ) - mask = mask.view(mask.shape[0], mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]) - mask = mask.transpose(1, 2) - mask = resize_mask(1 - mask, mask_latents) + clean_latents = _batch_encode_vae(pixel_values) + else: + clean_latents = None + + if args.train_mode != "normal": + # Encode inpaint latents. + mask_latents = _batch_encode_vae(mask_pixel_values) + if vae_stream_2 is not None: + torch.cuda.current_stream().wait_stream(vae_stream_2) + + # Encode clean latents for teacher forcing + clean_latents = None + if args.use_teacher_forcing: + clean_latents = _batch_encode_vae(pixel_values) + + mask = rearrange(mask, "b f c h w -> b c f h w") + mask = torch.concat( + [ + torch.repeat_interleave(mask[:, :, 0:1], repeats=4, dim=2), + mask[:, :, 1:] + ], dim=2 + ) + mask = mask.view(mask.shape[0], mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]) + mask = mask.transpose(1, 2) + mask = resize_mask(1 - mask, mask_latents) - inpaint_latents = torch.concat([mask, mask_latents], dim=1) + inpaint_latents = torch.concat([mask, mask_latents], dim=1) - clip_context = [] - for clip_pixel_value in clip_pixel_values: - clip_image = Image.fromarray(np.uint8(clip_pixel_value.float().cpu().numpy())) - clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(clip_image_encoder.device, weight_dtype) - _clip_context = clip_image_encoder([clip_image[:, None, :, :]]) - clip_context.append(_clip_context) - clip_context = torch.cat(clip_context) + clip_context = [] + for clip_pixel_value in clip_pixel_values: + clip_image = 
Image.fromarray(np.uint8(clip_pixel_value.float().cpu().numpy())) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(clip_image_encoder.device, weight_dtype) + _clip_context = clip_image_encoder([clip_image[:, None, :, :]]) + clip_context.append(_clip_context) + clip_context = torch.cat(clip_context) - target_shape = mask_latents.size() + if args.use_teacher_forcing: + target_shape = clean_latents.size() + else: + target_shape = mask_latents.size() else: text = batch['text'] if args.fix_sample_size is not None: @@ -1874,21 +1883,7 @@ def _batch_encode_vae(pixel_values): int(local_sample_size[0] // vae.spatial_compression_ratio), int(local_sample_size[1] // vae.spatial_compression_ratio), ) - - # Encode clean latents for teacher forcing in T2V mode clean_latents = None - if args.use_teacher_forcing: - with torch.no_grad(): - pixel_values = batch["pixel_values"].to(weight_dtype) - pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") - bs = args.vae_mini_batch - new_pixel_values = [] - for i in range(0, pixel_values.shape[0], bs): - pixel_values_bs = pixel_values[i : i + bs] - pixel_values_bs = vae.encode(pixel_values_bs)[0] - pixel_values_bs = pixel_values_bs.sample() - new_pixel_values.append(pixel_values_bs) - clean_latents = torch.cat(new_pixel_values, dim = 0) if args.low_vram: vae.to('cpu') @@ -2008,26 +2003,11 @@ def convert_flow_pred_to_x0( # --- Main Training Logic --- bsz, channel, num_frames, height, width = target_shape if step % args.gen_update_interval == 0: - # Self-Forcing training: create block mask for causal training - patch_h, patch_w = accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:] - - # frame_seqlen: tokens per frame AFTER VAE compression and patching - # VAE compresses 8x, then patches are extracted with patch_size - frame_seqlen = (height * width) // (patch_h * patch_w) - - # Create block mask if not exists or parameters changed - 
accelerator.unwrap_model(generator_transformer3d).create_block_mask_for_training( - num_frames=num_frames, - frame_seqlen=frame_seqlen, - num_frame_per_block=args.num_frame_per_block, - independent_first_frame=args.independent_first_frame, - device=accelerator.device - ) - if args.use_kv_cache_training: # === KV cache block-by-block training (original Self-Forcing) === # Calculate frame_seq_length + patch_h, patch_w = accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:] frame_seq_length = (target_shape[3] * target_shape[4]) // (patch_h * patch_w) # Determine block structure @@ -2081,28 +2061,15 @@ def convert_flow_pred_to_x0( # Decide whether to use teacher forcing for this video (once per video, not per block) use_teacher_forcing_step = ( args.use_teacher_forcing and - torch.rand(1, generator=torch_rng).item() < args.teacher_forcing_prob + torch.rand(1, generator=torch_rng, device=accelerator.device).item() < args.teacher_forcing_prob ) - # Prepare clean_x and aug_t for teacher forcing - clean_x = None - aug_t = None - if use_teacher_forcing_step and clean_latents is not None: - aug_t = torch.zeros(bsz, device=accelerator.device, dtype=torch.int64) - for block_idx, current_num_frames in enumerate(all_num_frames): # Extract noise for current block start_idx = current_start_frame - num_input_frames end_idx = start_idx + current_num_frames noisy_input = generator_noise[:, :, start_idx:end_idx] - # Extract clean latents for current block if using teacher forcing - if use_teacher_forcing_step and clean_latents is not None: - clean_x_block = clean_latents[:, :, start_idx:end_idx] - clean_x = [clean_x_block[i] for i in range(bsz)] - else: - clean_x = None - # Denoise loop for current block num_denoising_steps = len(denoising_step_list) final_step_index = generate_and_sync_list(num_denoising_steps, device=noisy_input.device)[0] @@ -2123,7 +2090,7 @@ def convert_flow_pred_to_x0( noisy_input_list = [noisy_input[i] for i in range(bsz)] # Use full seq_len 
(consistent with inference code) - full_seq_len = frame_seqlen * num_frames + full_seq_len = frame_seq_length * num_frames generator_pred_block = generator_transformer3d( x=noisy_input_list, @@ -2136,8 +2103,6 @@ def convert_flow_pred_to_x0( cache_start=None, y=inpaint_latents if args.train_mode != "normal" else None, clip_fea=clip_context if args.train_mode != "normal" else None, - clean_x=clean_x, - aug_t=aug_t, ) # Stack list output to tensor: [B, C, F, H, W] @@ -2168,25 +2133,28 @@ def convert_flow_pred_to_x0( # Record output output_pred[:, :, current_start_frame:current_start_frame + current_num_frames] = generator_pred_block - # Update KV cache with context noise + # Update KV cache with clean context (teacher forcing) or noisy context if block_idx < len(all_num_frames) - 1: context_timestep = torch.ones([bsz, current_num_frames], device=accelerator.device, dtype=torch.int64) * args.context_noise - # Add context noise - generator_pred_block_noisy = add_noise( - generator_pred_block, - torch.randn(generator_pred_block.shape, dtype=generator_pred_block.dtype, device=generator_pred_block.device, generator=torch_rng), - context_timestep[:, 0] - ) + # Use clean latents for teacher forcing, otherwise add noise + if use_teacher_forcing_step and clean_latents is not None: + context_input = clean_latents[:, :, start_idx:end_idx] + else: + context_input = add_noise( + generator_pred_block, + torch.randn(generator_pred_block.shape, dtype=generator_pred_block.dtype, device=generator_pred_block.device, generator=torch_rng), + context_timestep[:, 0] + ) - generator_pred_block_noisy_list = [generator_pred_block_noisy[i] for i in range(bsz)] + context_input_list = [context_input[i] for i in range(bsz)] # Use full seq_len (consistent with inference code) - full_seq_len = frame_seqlen * num_frames + full_seq_len = frame_seq_length * num_frames with torch.no_grad(): generator_transformer3d( - x=generator_pred_block_noisy_list, + x=context_input_list, context=prompt_embeds, 
t=context_timestep, seq_len=full_seq_len, @@ -2202,30 +2170,52 @@ def convert_flow_pred_to_x0( # Final output generator_pred = output_pred - seq_len = frame_seqlen * num_frames # For fake/real score computation - + seq_len = frame_seq_length * num_frames # For fake/real score computation + else: - # === Original block mask training === + # === Block mask training (flex attention, no KV cache) === + # Block mask training: use flex attention to process entire video at once + + patch_h_bm, patch_w_bm = accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:] + frame_seqlen_bm = (height * width) // (patch_h_bm * patch_w_bm) + # Standard backward simulation training generator_noise = torch.randn(target_shape, device=accelerator.device, generator=torch_rng, dtype=weight_dtype) num_denoising_steps = len(denoising_step_list) final_step_index = generate_and_sync_list(num_denoising_steps, device=generator_noise.device)[0] # Precompute seq_len once (same for all steps) - seq_len = frame_seqlen * num_frames + seq_len = frame_seqlen_bm * num_frames # Decide whether to use teacher forcing for this step use_teacher_forcing_step = ( args.use_teacher_forcing and - torch.rand(1, generator=torch_rng).item() < args.teacher_forcing_prob + torch.rand(1, generator=torch_rng, device=accelerator.device).item() < args.teacher_forcing_prob ) - # Prepare clean_x and aug_t for teacher forcing - clean_x = None - aug_t = None + # Create appropriate block mask based on teacher forcing decision if use_teacher_forcing_step and clean_latents is not None: + # Teacher forcing: clean + noisy sequence mask + accelerator.unwrap_model(generator_transformer3d).create_teacher_forcing_mask( + device=accelerator.device, + num_frames=num_frames, + frame_seqlen=frame_seqlen_bm, + num_frame_per_block=args.num_frame_per_block, + ) + # Prepare clean_x and aug_t for teacher forcing clean_x = [clean_latents[i] for i in range(clean_latents.size(0))] aug_t = torch.zeros(bsz, device=accelerator.device, 
dtype=torch.int64) + else: + # Standard causal mask + accelerator.unwrap_model(generator_transformer3d).create_block_mask_for_training( + num_frames=num_frames, + frame_seqlen=frame_seqlen_bm, + num_frame_per_block=args.num_frame_per_block, + independent_first_frame=args.independent_first_frame, + device=accelerator.device + ) + clean_x = None + aug_t = None for index, current_timestep in enumerate(denoising_step_list): is_final_step = (index == final_step_index) @@ -2240,15 +2230,19 @@ def convert_flow_pred_to_x0( context_manager = torch.no_grad() if not is_final_step else contextlib.nullcontext() with context_manager: + # Convert to list format for transformer + generator_noise_list = [generator_noise[i] for i in range(bsz)] + clean_x_list = [clean_latents[i] for i in range(bsz)] if clean_x is not None else None + # Use block_mask for causal training (一次性处理整个视频) generator_pred = generator_transformer3d( - x=generator_noise, + x=generator_noise_list, context=prompt_embeds, t=timestep, seq_len=seq_len, y=inpaint_latents if args.train_mode != "normal" else None, clip_fea=clip_context if args.train_mode != "normal" else None, - clean_x=clean_x, + clean_x=clean_x_list, aug_t=aug_t, ) generator_pred = convert_flow_pred_to_x0( @@ -2435,6 +2429,12 @@ def convert_flow_pred_to_x0( num_input_frames = 0 output_pred = torch.zeros_like(fake_score_critic_noise) + # Decide whether to use teacher forcing for this video + use_teacher_forcing_step = ( + args.use_teacher_forcing and + torch.rand(1, generator=torch_rng, device=accelerator.device).item() < args.teacher_forcing_prob + ) + for block_idx, current_num_frames in enumerate(all_num_frames): start_idx = current_start_frame - num_input_frames end_idx = start_idx + current_num_frames @@ -2458,7 +2458,7 @@ def convert_flow_pred_to_x0( noisy_input_list = [noisy_input[i] for i in range(bsz)] # Use full seq_len (consistent with inference code) - full_seq_len = frame_seqlen * num_frames + full_seq_len = frame_seq_length * num_frames 
fake_score_denoised_pred_block = generator_transformer3d( x=noisy_input_list, @@ -2498,23 +2498,28 @@ def convert_flow_pred_to_x0( output_pred[:, :, current_start_frame:current_start_frame + current_num_frames] = fake_score_denoised_pred_block - # Update KV cache + # Update KV cache with clean context (teacher forcing) or noisy context if block_idx < len(all_num_frames) - 1: context_timestep = torch.ones([bsz, current_num_frames], device=accelerator.device, dtype=torch.int64) * args.context_noise - fake_score_denoised_pred_noisy = add_noise( - fake_score_denoised_pred_block, - torch.randn(fake_score_denoised_pred_block.shape, dtype=fake_score_denoised_pred_block.dtype, device=fake_score_denoised_pred_block.device, generator=torch_rng), - context_timestep[:, 0] - ) - fake_score_denoised_pred_noisy_list = [fake_score_denoised_pred_noisy[i] for i in range(bsz)] + # Use clean latents for teacher forcing, otherwise add noise + if use_teacher_forcing_step and clean_latents is not None: + context_input = clean_latents[:, :, start_idx:end_idx] + else: + context_input = add_noise( + fake_score_denoised_pred_block, + torch.randn(fake_score_denoised_pred_block.shape, dtype=fake_score_denoised_pred_block.dtype, device=fake_score_denoised_pred_block.device, generator=torch_rng), + context_timestep[:, 0] + ) + + context_input_list = [context_input[i] for i in range(bsz)] # Use full seq_len (consistent with inference code) - full_seq_len = frame_seqlen * num_frames + full_seq_len = frame_seq_length * num_frames with torch.no_grad(): generator_transformer3d( - x=fake_score_denoised_pred_noisy_list, + x=context_input_list, context=prompt_embeds, t=context_timestep, seq_len=full_seq_len, @@ -2532,14 +2537,45 @@ def convert_flow_pred_to_x0( seq_len = frame_seq_length * num_frames else: - # Original block mask mode with torch.no_grad(): + # Block mask mode: use flex attention to process entire video at once + + patch_h_bm, patch_w_bm = 
accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:] + frame_seqlen_bm = (height * width) // (patch_h_bm * patch_w_bm) + seq_len = frame_seqlen_bm * num_frames + fake_score_critic_noise = torch.randn(target_shape, device=accelerator.device, generator=torch_rng, dtype=weight_dtype) num_denoising_steps = len(denoising_step_list) final_step_index = generate_and_sync_list(num_denoising_steps, device=fake_score_critic_noise.device)[0] - patch_h, patch_w = accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:] - seq_len = math.ceil((width * height) / (patch_h * patch_w) * num_frames) + # Decide whether to use teacher forcing for this step + use_teacher_forcing_step = ( + args.use_teacher_forcing and + torch.rand(1, generator=torch_rng, device=accelerator.device).item() < args.teacher_forcing_prob + ) + + # Create appropriate block mask based on teacher forcing decision + if use_teacher_forcing_step and clean_latents is not None: + # Teacher forcing: clean + noisy sequence mask + accelerator.unwrap_model(generator_transformer3d).create_teacher_forcing_mask( + device=accelerator.device, + num_frames=num_frames, + frame_seqlen=frame_seqlen_bm, + num_frame_per_block=args.num_frame_per_block, + ) + clean_x = [clean_latents[i] for i in range(clean_latents.size(0))] + aug_t = torch.zeros(bsz, device=accelerator.device, dtype=torch.int64) + else: + # Standard causal mask + accelerator.unwrap_model(generator_transformer3d).create_block_mask_for_training( + num_frames=num_frames, + frame_seqlen=frame_seqlen_bm, + num_frame_per_block=args.num_frame_per_block, + independent_first_frame=args.independent_first_frame, + device=accelerator.device + ) + clean_x = None + aug_t = None for index, current_timestep in enumerate(denoising_step_list): is_final_step = (index == final_step_index) @@ -2550,16 +2586,21 @@ def convert_flow_pred_to_x0( dtype=torch.int64 ) + with torch.cuda.amp.autocast(dtype=weight_dtype), 
torch.cuda.device(device=accelerator.device): + # Convert to list format for transformer + fake_score_critic_noise_list = [fake_score_critic_noise[i] for i in range(bsz)] + clean_x_list = [clean_latents[i] for i in range(bsz)] if clean_x is not None else None + fake_score_denoised_pred = generator_transformer3d( - x=fake_score_critic_noise, + x=fake_score_critic_noise_list, context=prompt_embeds, t=timestep, seq_len=seq_len, y=inpaint_latents if args.train_mode != "normal" else None, clip_fea=clip_context if args.train_mode != "normal" else None, - clean_x=None, - aug_t=None, + clean_x=clean_x_list, + aug_t=aug_t, ) fake_score_denoised_pred = convert_flow_pred_to_x0( scheduler=noise_scheduler, From fd5bd79d771b663c4bcb5b31c06dc3bb3ceebaad Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 12:15:21 +0800 Subject: [PATCH 08/16] Update isort --- videox_fun/pipeline/pipeline_wan_self_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/videox_fun/pipeline/pipeline_wan_self_forcing.py b/videox_fun/pipeline/pipeline_wan_self_forcing.py index 43688ee1..5375c010 100644 --- a/videox_fun/pipeline/pipeline_wan_self_forcing.py +++ b/videox_fun/pipeline/pipeline_wan_self_forcing.py @@ -13,8 +13,8 @@ from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from ..models import (AutoencoderKLWan, AutoTokenizer, - WanT5EncoderModel, WanTransformer3DModel_SelfForcing) +from ..models import (AutoencoderKLWan, AutoTokenizer, WanT5EncoderModel, + WanTransformer3DModel_SelfForcing) from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas) from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler From 65ae864f6f5c229dac0e4516bcb2d70ea3fb74ff Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 12:17:35 +0800 Subject: [PATCH 09/16] Update sh --- scripts/wan2.1_self_forcing/train_distill.sh | 1 + 1 file 
changed, 1 insertion(+) diff --git a/scripts/wan2.1_self_forcing/train_distill.sh b/scripts/wan2.1_self_forcing/train_distill.sh index dfdb5290..f324652d 100644 --- a/scripts/wan2.1_self_forcing/train_distill.sh +++ b/scripts/wan2.1_self_forcing/train_distill.sh @@ -41,4 +41,5 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --uniform_sampling \ --train_mode="normal" \ --trainable_modules "." \ + --use_teacher_forcing \ --low_vram From 55ef1da66aa2a5b6e64ca73582691e50f27bb914 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 15:00:47 +0800 Subject: [PATCH 10/16] Update progress_bar and READMES --- .../README_TRAIN_DISTILL.md | 716 +++++++++++++++++ .../README_TRAIN_DISTILL_zh-CN.md | 717 ++++++++++++++++++ .../pipeline/pipeline_wan_self_forcing.py | 173 ++--- 3 files changed, 1521 insertions(+), 85 deletions(-) create mode 100755 scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md create mode 100755 scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md diff --git a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md new file mode 100755 index 00000000..9c8dfd09 --- /dev/null +++ b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md @@ -0,0 +1,716 @@ +# Wan2.1 Self-Forcing Distillation Training Guide + +This document provides a complete workflow for Self-Forcing distillation of Wan2.1 including environment setup, data preparation, distributed training, and inference testing. + +> **Note**: Wan2.1 Self-Forcing is a causal video generation model that supports text-to-video (T2V). Combined with distillation, this training code can reduce inference steps from 25-50 to 4-8 steps while enabling block-by-block causal generation with teacher forcing. + +--- + +## Table of Contents +- [1. Environment Setup](#1-environment-setup) +- [2. 
Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. Distillation Training](#3-distillation-training) + - [3.1 Download Pretrained Models](#31-download-pretrained-models) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 Common Training Parameters](#33-common-training-parameters) + - [3.4 Training Validation](#34-training-validation) + - [3.5 Training with FSDP](#35-training-with-fsdp) + - [3.6 Other Backends](#36-other-backends) + - [3.7 Multi-Node Distributed Training](#37-multi-node-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameters](#41-inference-parameters) + - [4.2 Text-to-Video (T2V) Inference](#42-text-to-video-t2v-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) + +--- + +## 1. 
Environment Setup + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**Method 2: Manual Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that your machine has correctly installed GPU drivers and CUDA environment, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset that contains several training data samples. + +```bash +# Download official example dataset +modelscope download --dataset PAI/X-Fun-Videos-Demo --local_dir ./datasets/X-Fun-Videos-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 video001.mp4 +│ │ ├── 📄 video002.mp4 +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format + +**Relative Path Format** (example format): +```json +[ + { + "file_path": "train/video001.mp4", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "type": "video", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/video002.mp4", + "text": "A person walking through a forest, cinematic view", + "type": "video", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/videos/sunset.mp4", + "text": "A beautiful sunset over the ocean", + "type": "video", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Field Descriptions**: +- `file_path`: Video path (relative or absolute path) +- `text`: Video description (English prompt) +- `type`: Data type, fixed as `"video"` +- `width` / `height`: Video dimensions (**recommended** to provide for bucket training. If not provided, it will be automatically read during training, which may affect training speed when data is stored on slower systems like OSS). + - You can use `scripts/process_json_add_width_and_height.py` to extract width and height fields for JSON files without these fields, supporting both images and videos. + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Videos-Demo/metadata.json --output_file datasets/X-Fun-Videos-Demo/metadata_add_width_height.json`. + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure in the training script: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure in the training script: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. 
If the dataset is stored on external storage (such as NAS, OSS) or shared across multiple machines, use absolute paths. + +--- + +## 3. Distillation Training + +### 3.1 Download Pretrained Models + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download Wan2.1 official weights +# T2V model (text-to-video) +modelscope download --model Wan-AI/Wan2.1-T2V-1.3B --local_dir models/Diffusion_Transformer/Wan2.1-T2V-1.3B + +# Self-Forcing +hf download gdhe17/Self-Forcing --local-dir models/Diffusion_Transformer/Self-Forcing +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +After downloading data according to **2.1 Quick Test Dataset** and downloading weights according to **3.1 Download Pretrained Models**, you can directly copy and run the quick start command. + +We recommend using DeepSpeed-Zero-2 and FSDP for training. Here we use DeepSpeed-Zero-2 as an example to configure the shell file. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether to shard model weights. **If you use multiple GPUs and encounter insufficient GPU memory with DeepSpeed-Zero-2**, you can switch to FSDP for training. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." 
\ + --use_teacher_forcing \ + --low_vram +``` + +### 3.3 Common Training Parameters + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Pretrained model path | `models/Diffusion_Transformer/Wan2.1-T2V-1.3B/` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Batch size per GPU | 1 | +| `--image_sample_size` | Maximum image training resolution | 640 | +| `--video_sample_size` | Maximum video training resolution | 640 | +| `--token_sample_size` | Token sample size | 640 | +| `--video_sample_stride` | Video sampling stride | 2 | +| `--video_sample_n_frames` | Number of video frames | 81 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (effectively increases batch) | 1 | +| `--dataloader_num_workers` | DataLoader worker processes | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate (generator) | 2e-06 | +| `--learning_rate_critic` | Initial learning rate (critic) | 2e-07 | +| `--lr_scheduler` | Learning rate scheduler | `constant_with_warmup` | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed | 42 | +| `--output_dir` | Output directory | `output_dir_wan2.1_self_forcing_distill` | +| `--gradient_checkpointing` | Enable gradient checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW weight decay | 3e-2 | +| `--adam_epsilon` | AdamW epsilon | 1e-10 | +| `--vae_mini_batch` | VAE encoding mini-batch size | 1 | +| `--max_grad_norm` | Gradient clipping threshold | 0.05 | +| `--enable_bucket` | Enable bucket training, no cropping, group by resolution | - | +| `--random_hw_adapt` | Auto-scale 
images/videos to random sizes in `[min_size, max_size]` range | - | +| `--training_with_video_token_length` | Train based on token length, supports arbitrary resolutions | - | +| `--uniform_sampling` | Uniform timestep sampling | - | +| `--low_vram` | Low VRAM mode | - | +| `--train_mode` | Training mode: `normal` (T2V) | `normal` | +| `--resume_from_checkpoint` | Resume training path, use `"latest"` to auto-select latest checkpoint | None | +| `--validation_steps` | Run validation every N steps | 2000 | +| `--validation_epochs` | Run validation every N epochs | 5 | +| `--validation_prompts` | Prompts for video generation validation | `"A dog shaking head..."` | +| `--trainable_modules` | Trainable modules (`"."` means all modules) | `"."` | + +**Distillation-Specific Parameters**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--denoising_step_indices_list` | Denoising step indices list (core distillation parameter) | `1000 750 500 250` | +| `--real_guidance_scale` | Real guidance scale for scoring | 6.0 | +| `--fake_guidance_scale` | Fake guidance scale for scoring | 0.0 | +| `--gen_update_interval` | Generator update interval | 5 | +| `--negative_prompt` | Negative prompt for distillation | Chinese negative prompt | +| `--train_sampling_steps` | Training sampling steps | 1000 | + +**Self-Forcing-Specific Parameters**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--fix_sample_size` | Fixed sample size `[height, width]` for training | `480 832` | +| `--num_frame_per_block` | Number of frames per block for causal training | 3 | +| `--independent_first_frame` | Whether first frame is independent (`[1, N, N, ...]` pattern) | - | +| `--use_kv_cache_training` | Use KV cache block-by-block training (matches original Self-Forcing) | - | +| `--context_noise` | Context noise level for KV cache update | 0 | +| `--use_teacher_forcing` | Enable teacher forcing training (pass clean_x to transformer) | - | +| 
`--teacher_forcing_prob` | Probability of applying teacher forcing per step | 1.0 | + +**Sample Size Configuration Guide**: +- `video_sample_size` represents the resolution size of videos; when `random_hw_adapt` is True, it represents the minimum value between video and image resolutions. +- `image_sample_size` represents the resolution size of images; when `random_hw_adapt` is True, it represents the maximum value between video and image resolutions. +- `token_sample_size` represents the resolution corresponding to the maximum token length when `training_with_video_token_length` is True. +- Due to potential confusion in configuration, **if you don't require arbitrary resolution for finetuning**, it is recommended to set `video_sample_size`, `image_sample_size`, and `token_sample_size` to the same fixed value, such as **(320, 480, 512, 640, 960)**. + - **All set to 320** represents **240P**. + - **All set to 480** represents **320P**. + - **All set to 640** represents **480P**. + - **All set to 960** represents **720P**. + +**Token Length Training Guide**: +- When `training_with_video_token_length` is enabled, the model trains based on token length. +- For example: A video with 512x512 resolution and 49 frames has a token length of 13,312, requiring `token_sample_size = 512`. + - At 512x512 resolution, the number of video frames is 49 (~= 512 * 512 * 49 / 512 / 512). + - At 768x768 resolution, the number of video frames is 21 (~= 512 * 512 * 49 / 768 / 768). + - At 1024x1024 resolution, the number of video frames is 9 (~= 512 * 512 * 49 / 1024 / 1024). + - These resolutions combined with their corresponding frame counts allow the model to generate videos of different sizes. + +### 3.4 Training Validation + +You can configure validation parameters to periodically generate test videos during training, allowing you to monitor training progress and model quality. 
+ +**Validation Parameter Descriptions**: + +| Parameter | Description | Recommended Value | +|------|------|--------| +| `--validation_steps` | Run validation every N steps | 2000 | +| `--validation_epochs` | Run validation every N epochs | 5 | +| `--validation_prompts` | Prompts for video generation validation | English prompts | + +**T2V Validation Example**: + +```bash + --validation_steps=2000 \ + --validation_epochs=5 \ + --validation_prompts="A brown dog shaking its head, sitting on a light-colored sofa in a cozy room. Behind the dog, there's a framed picture on the shelf, surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere." +``` + +**Notes**: +- Validation videos will be saved to the `output_dir` directory +- Multi-prompt validation format: `--validation_prompts "prompt1" "prompt2" "prompt3"` + +### 3.5 Training with FSDP + +**If you use multiple GPUs and encounter insufficient GPU memory with DeepSpeed-Zero-2**, you can switch to FSDP for training. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." \ + --use_teacher_forcing \ + --low_vram +``` + +### 3.6 Other Backends + +#### 3.6.1 Training with DeepSpeed-Zero-3 + +DeepSpeed Zero-3 is not highly recommended at the moment. In this repository, using FSDP has fewer errors and is more stable. + +DeepSpeed Zero-3 is suitable for 14B Wan at high resolutions. 
After training, you can use the following command to get the final model: +```bash +python scripts/zero_to_bf16.py output_dir/checkpoint-{your-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +Training shell command is as follows: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "."
\ + --use_teacher_forcing \ + --low_vram +``` + +#### 3.6.2 Training without DeepSpeed and FSDP + +**This approach is not recommended because there is no memory-saving backend, which can easily cause out-of-memory errors**. We only provide the training shell for reference. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." 
\ + --use_teacher_forcing \ + --low_vram +``` + +### 3.7 Multi-Node Distributed Training + +**Suitable for**: Ultra-large-scale datasets, faster training speed + +#### 3.7.1 Environment Configuration + +Assuming 2 machines, each with 8 GPUs: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Rank of this machine (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + 
--adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." \ + --use_teacher_forcing \ + --low_vram +``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as Machine 0 +``` + +#### 3.7.2 Multi-Node Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must be able to access the same data paths (NFS/shared storage) + +--- + +## 4. 
Inference Testing + +### 4.1 Inference Parameters + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | GPU memory mode, see options below | `sequential_cpu_offload` | +| `ulysses_degree` | Ulysses parallelism degree for multi-GPU inference | 1 | +| `ring_degree` | Ring parallelism degree for multi-GPU inference | 1 | +| `fsdp_dit` | Use FSDP for Transformer during multi-GPU inference to save memory | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder during multi-GPU inference | `True` | +| `compile_dit` | Compile Transformer for faster inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/Wan2.1-T2V-1.3B` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Trained Transformer weight path | `"models/Diffusion_Transformer/Self-Forcing/checkpoints/self_forcing_dmd.pt"` | +| `vae_path` | Trained VAE weight path | `None` | +| `lora_path` | LoRA weight path | `None` | +| `sample_size` | Generated video resolution `[height, width]` | `[480, 832]` | +| `video_length` | Number of frames to generate | `81` | +| `fps` | Frames per second | `16` | +| `weight_dtype` | Model weight dtype, use `torch.float16` for GPUs that don't support bf16 | `torch.bfloat16` | +| `num_frame_per_block` | Number of frames to generate per block (1 for standard causal, higher for faster but more memory) | 3 | +| `local_attn_size` | Local attention window size (-1 for global attention) | -1 | +| `independent_first_frame` | Whether first frame is generated independently | `False` | +| `context_noise` | Context noise level for generation | 0.0 | +| `prompt` | Positive prompt describing what to generate | `"A stylish woman walks down a Tokyo street..."` | +| `negative_prompt` | Negative prompt to avoid certain content | Chinese negative prompt | +| `guidance_scale` | Guidance strength (distillation 
models typically use 1.0) | 1.0 | +| `seed` | Random seed for reproducibility | 43 | +| `num_inference_steps` | Number of inference steps (typically 4 for distillation models) | 4 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Path to save generated videos | `samples/wan-videos-self-forcing-t2v` | + +**GPU Memory Mode Descriptions**: + +| Mode | Description | Memory Usage | +|------|------|---------| +| `model_full_load` | Entire model loaded to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer groups switch between CPU/CUDA | Low | +| `sequential_cpu_offload` | Layer-by-layer offload (slowest) | Lowest | + +### 4.2 Text-to-Video (T2V) Inference + +Run single GPU inference: + +```bash +python examples/wan2.1_self_forcing/predict_t2v.py +``` + +Edit `examples/wan2.1_self_forcing/predict_t2v.py` according to your needs. For first-time inference, focus on the following key parameters. For other parameters, please refer to the inference parameter descriptions above. 
+ +```python +# Choose based on GPU memory +GPU_memory_mode = "sequential_cpu_offload" +# Your actual model path +model_name = "models/Diffusion_Transformer/Wan2.1-T2V-1.3B" +# Trained weight path +transformer_path = "models/Diffusion_Transformer/Self-Forcing/checkpoints/self_forcing_dmd.pt" +# Distillation models typically use 4 steps +num_inference_steps = 4 +# Distillation models guidance_scale is typically 1.0 +guidance_scale = 1.0 + +# Self-Forcing causal inference config +num_frame_per_block = 3 # Number of frames to generate per block +local_attn_size = -1 # Local attention window size (-1 for global attention) +independent_first_frame = False +context_noise = 0.0 + +# Write according to your generated content +prompt = "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage." +# ... +``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, accelerated inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/wan2.1_self_forcing/predict_t2v.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs used +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelism +ring_degree = 1 # Sequence dimension parallelism +``` + +**Configuration Principles**: +- `ulysses_degree` must evenly divide the model's head count (i.e. the number of attention heads must be divisible by `ulysses_degree`) +- `ring_degree` splits on the sequence dimension, which adds communication overhead. Avoid using it whenever the head count is evenly divisible by the number of GPUs, so that `ulysses_degree` alone can cover all GPUs.
+ +**Configuration Examples**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelism | +| 8 | 8 | 1 | Head parallelism | +| 8 | 4 | 2 | Hybrid parallelism | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/wan2.1_self_forcing/predict_t2v.py +``` + +--- + +## 5. Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun \ No newline at end of file diff --git a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md new file mode 100755 index 00000000..f453e5fd --- /dev/null +++ b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md @@ -0,0 +1,717 @@ +# Wan2.1 Self-Forcing 蒸馏训练指南 + +本文档提供了将 Wan2.1 进行 Self-Forcing 蒸馏的完整工作流,包括环境配置、数据准备、分布式训练和推理测试。 + +> **说明**:Wan2.1 Self-Forcing 是一个支持文生视频(T2V)的因果视频生成模型。结合蒸馏训练,该代码可以将推理步数从 25-50 步减少到 4-8 步,同时支持逐块因果生成与 teacher forcing。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、蒸馏训练](#三蒸馏训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 训练常用参数解析](#33-训练常用参数解析) + - [3.4 训练验证](#34-训练验证) + - [3.5 使用 FSDP 训练](#35-使用-fsdp-训练) + - [3.6 其他后端](#36-其他后端) + - [3.7 多机分布式训练](#37-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 文生视频(T2V)推理](#42-文生视频t2v推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout 
onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后依次执行以下命令: + +``` +# 拉取镜像 +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# 进入容器 +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Videos-Demo --local_dir ./datasets/X-Fun-Videos-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 video001.mp4 +│ │ ├── 📄 video002.mp4 +│ │ └── 📄 ...
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/video001.mp4", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "type": "video", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/video002.mp4", + "text": "A person walking through a forest, cinematic view", + "type": "video", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/videos/sunset.mp4", + "text": "A beautiful sunset over the ocean", + "type": "video", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:视频路径(相对或绝对路径) +- `text`:视频描述(英文提示词) +- `type`:数据类型,固定为 `"video"` +- `width` / `height`:视频宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Videos-Demo/metadata.json --output_file datasets/X-Fun-Videos-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、蒸馏训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 Wan2.1 官方权重 +# T2V 模型(文生视频) +modelscope download --model Wan-AI/Wan2.1-T2V-1.3B --local_dir models/Diffusion_Transformer/Wan2.1-T2V-1.3B + +# Self-Forcing +hf download gdhe17/Self-Forcing --local-dir models/Diffusion_Transformer/Self-Forcing + +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + 
+推荐使用DeepSpeed-Zero-2与FSDP方案进行训练。这里使用DeepSpeed-Zero-2为例配置shell文件。 + +本文中DeepSpeed-Zero-2与FSDP的差别在于是否对模型权重进行分片,**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." 
\ + --use_teacher_forcing \ + --low_vram +``` + +### 3.3 训练常用参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/Wan2.1-T2V-1.3B/` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 图像最大训练分辨率 | 640 | +| `--video_sample_size` | 视频最大训练分辨率 | 640 | +| `--token_sample_size` | Token 采样尺寸 | 640 | +| `--video_sample_stride` | 视频采样步幅 | 2 | +| `--video_sample_n_frames` | 视频采样帧数 | 81 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率(生成器) | 2e-06 | +| `--learning_rate_critic` | 初始学习率(判别器) | 2e-07 | +| `--lr_scheduler` | 学习率调度器 | `constant_with_warmup` | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子 | 42 | +| `--output_dir` | 输出目录 | `output_dir_wan2.1_self_forcing_distill` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW 权重衰减 | 3e-2 | +| `--adam_epsilon` | AdamW epsilon 值 | 1e-10 | +| `--vae_mini_batch` | VAE 编码时的迷你批次大小 | 1 | +| `--max_grad_norm` | 梯度裁剪阈值 | 0.05 | +| `--enable_bucket` | 启用分桶训练,不裁剪图片/视频,按分辨率分组训练 | - | +| `--random_hw_adapt` | 自动缩放图片/视频到 `[min_size, max_size]` 范围内的随机尺寸 | - | +| `--training_with_video_token_length` | 根据 token 长度训练,支持任意分辨率 | - | +| `--uniform_sampling` | 均匀采样 timestep | - | +| `--low_vram` | 低显存模式 | - | +| `--train_mode` | 训练模式:`normal`(T2V) | `normal` | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--validation_steps` | 每 N 步执行一次验证 | 2000 | +| `--validation_epochs` | 每 N 个epoch执行一次验证 | 5 | +| `--validation_prompts` | 验证视频生成的提示词 | `"一只棕色的狗摇着头..."` | +| `--trainable_modules` 
| 可训练模块(`"."` 表示所有模块) | `"."` | + +**蒸馏特有参数**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--denoising_step_indices_list` | 去噪步骤列表(蒸馏核心参数) | `1000 750 500 250` | +| `--real_guidance_scale` | 用于评分的真实 guidance scale | 6.0 | +| `--fake_guidance_scale` | 用于评分的虚拟 guidance scale | 0.0 | +| `--gen_update_interval` | 生成器更新间隔 | 5 | +| `--negative_prompt` | 用于蒸馏的负向提示词 | 中文负向提示词 | +| `--train_sampling_steps` | 训练采样步数 | 1000 | + +**Self-Forcing 特有参数**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--fix_sample_size` | 固定训练尺寸 `[高度, 宽度]` | `480 832` | +| `--num_frame_per_block` | 每个块的帧数(用于因果训练) | 3 | +| `--independent_first_frame` | 第一帧是否独立生成(`[1, N, N, ...]` 模式) | - | +| `--use_kv_cache_training` | 使用 KV 缓存逐块训练(匹配原始 Self-Forcing) | - | +| `--context_noise` | KV 缓存更新的上下文噪声级别 | 0 | +| `--use_teacher_forcing` | 启用 teacher forcing 训练(将 clean_x 传给 transformer) | - | +| `--teacher_forcing_prob` | 每步应用 teacher forcing 的概率 | 1.0 | + +**Sample Size 配置指南**: +- `video_sample_size` 表示视频的分辨率大小;当 `random_hw_adapt` 为 True 时,表示视频和图像分辨率的最小值。 +- `image_sample_size` 表示图像的分辨率大小;当 `random_hw_adapt` 为 True 时,表示视频和图像分辨率的最大值。 +- `token_sample_size` 表示当 `training_with_video_token_length` 为 True 时,最大 token 长度对应的分辨率。 +- 由于配置可能产生混淆,**如果你不需要任意分辨率进行 finetuning**,建议将 `video_sample_size`、`image_sample_size` 和 `token_sample_size` 设置为相同的固定值,例如 **(320, 480, 512, 640, 960)**。 + - **全部设置为 320** 代表 **240P**。 + - **全部设置为 480** 代表 **320P**。 + - **全部设置为 640** 代表 **480P**。 + - **全部设置为 960** 代表 **720P**。 + +**Token Length 训练说明**: +- 当启用 `training_with_video_token_length` 时,模型根据 token 长度进行训练。 +- 例如:512x512 分辨率、49 帧的视频,其 token 长度为 13,312,需要设置 `token_sample_size = 512`。 + - 在 512x512 分辨率下,视频帧数为 49 (~= 512 * 512 * 49 / 512 / 512)。 + - 在 768x768 分辨率下,视频帧数为 21 (~= 512 * 512 * 49 / 768 / 768)。 + - 在 1024x1024 分辨率下,视频帧数为 9 (~= 512 * 512 * 49 / 1024 / 1024)。 + - 这些分辨率与对应帧数的组合,使模型能够生成不同尺寸的视频。 + +### 3.4 训练验证 + +你可以配置验证参数,在训练过程中定期生成测试视频,以便监控训练进度和模型质量。 + +**验证参数说明**: + +| 参数 | 说明 | 推荐值 | +|------|------|--------| +| 
`--validation_steps` | 每 N 步执行一次验证 | 2000 | +| `--validation_epochs` | 每 N 个epoch执行一次验证 | 5 | +| `--validation_prompts` | 验证视频生成的提示词 | 英文提示词 | + +**T2V 验证示例**: + +```bash + --validation_steps=2000 \ + --validation_epochs=5 \ + --validation_prompts="A brown dog shaking its head, sitting on a light-colored sofa in a cozy room. Behind the dog, there's a framed painting on a shelf, surrounded by pink flowers. The soft, warm lighting in the room creates a comfortable atmosphere." +``` + +**注意事项**: +- 验证视频会保存到 `output_dir` 目录中 +- 多提示词验证格式:`--validation_prompts "prompt1" "prompt2" "prompt3"` + +### 3.5 使用 FSDP 训练 + +**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + 
--lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." \ + --use_teacher_forcing \ + --low_vram +``` + +### 3.6 其他后端 + +#### 3.6.1 使用DeepSpeed-Zero-3进行训练 + +目前不太推荐使用 DeepSpeed Zero-3。在本仓库中,使用 FSDP 出错更少且更稳定。 + +DeepSpeed Zero-3 适合高分辨率的 14B Wan。训练后,您可以使用以下命令获取最终模型: +```bash +python scripts/zero_to_bf16.py output_dir/checkpoint-{your-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization +``` + +训练 shell 命令如下: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA.
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." \ + --use_teacher_forcing \ + --low_vram +``` + +#### 3.6.2 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练Shell用于参考训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." \ + --use_teacher_forcing \ + --low_vram +``` + +### 3.7 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.7.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/wan2.1_self_forcing/train_distill.py \ + --config_path="config/wan2.1/wan_civitai.yaml" \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --image_sample_size=640 \ + --video_sample_size=640 \ + --token_sample_size=640 \ + --fix_sample_size 480 832 \ + --video_sample_stride=2 \ + --video_sample_n_frames=81 \ + --train_batch_size=1 \ + --video_repeat=1 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-06 \ + --learning_rate_critic=2e-07 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_wan2.1_self_forcing_distill" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --random_hw_adapt \ + --training_with_video_token_length \ + --enable_bucket \ + --uniform_sampling \ + --train_mode="normal" \ + --trainable_modules "." \ + --use_teacher_forcing \ + --low_vram +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-T2V-1.3B/" +export DATASET_NAME="datasets/X-Fun-Videos-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.7.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +--- + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `sequential_cpu_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `True` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/Wan2.1-T2V-1.3B` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `"models/Diffusion_Transformer/Self-Forcing/checkpoints/self_forcing_dmd.pt"` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成视频分辨率 `[高度, 宽度]` | `[480, 832]` | +| `video_length` | 生成视频帧数 | `81` | +| `fps` | 每秒帧数 | `16` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `num_frame_per_block` | 每个块生成的帧数(1 为标准因果,更高则更快但需更多显存) | 3 | +| `local_attn_size` | 局部注意力窗口大小(-1 为全局注意力) | -1 | +| `independent_first_frame` | 第一帧是否独立生成 | `False` | +| `context_noise` | 生成时的上下文噪声级别 | 0.0 | +| `prompt` | 正向提示词,描述生成内容 | `"A stylish woman walks down a Tokyo street..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | 中文负向提示词 | +| `guidance_scale` | 引导强度(蒸馏模型通常使用 1.0) | 1.0 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数(蒸馏模型通常为 4) | 4 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成视频保存路径 | `samples/wan-videos-self-forcing-t2v` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| 
`model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 文生视频(T2V)推理 + +单卡推理运行如下命令: + +```bash +python examples/wan2.1_self_forcing/predict_t2v.py +``` + +根据需求修改编辑 `examples/wan2.1_self_forcing/predict_t2v.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "sequential_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/Wan2.1-T2V-1.3B" +# 训练好的权重路径 +transformer_path = "models/Diffusion_Transformer/Self-Forcing/checkpoints/self_forcing_dmd.pt" +# 蒸馏模型通常使用 4 步 +num_inference_steps = 4 +# 蒸馏模型 guidance_scale 通常为 1.0 +guidance_scale = 1.0 + +# Self-Forcing 因果推理配置 +num_frame_per_block = 3 # 每个块生成的帧数 +local_attn_size = -1 # 局部注意力窗口大小(-1 为全局注意力) +independent_first_frame = False +context_noise = 0.0 + +# 根据生成内容编写 +prompt = "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage." +# ... 
+``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/wan2.1_self_forcing/predict_t2v.py`: + +```python +# 确保 ulysses_degree × ring_degree = 使用的 GPU 数 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的 head 数 +- `ring_degree` 是在 sequence 维度切分,会影响通信开销,在 head 能整除的情况下尽量不要用 + +**配置示例**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单 GPU | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/wan2.1_self_forcing/predict_t2v.py +``` + +--- + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun diff --git a/videox_fun/pipeline/pipeline_wan_self_forcing.py b/videox_fun/pipeline/pipeline_wan_self_forcing.py index 5375c010..f415ec23 100644 --- a/videox_fun/pipeline/pipeline_wan_self_forcing.py +++ b/videox_fun/pipeline/pipeline_wan_self_forcing.py @@ -564,9 +564,6 @@ def __call__( else: timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) self._num_timesteps = len(timesteps) - if comfyui_progressbar: - from comfy.utils import ProgressBar - pbar = ProgressBar(num_inference_steps + 1) # 5. Prepare latents (noise) and output buffer separately latent_channels = self.transformer.config.in_channels @@ -604,9 +601,6 @@ def __call__( device=device, dtype=weight_dtype ) - - if comfyui_progressbar: - pbar.update(1) # 6. 
Calculate sequence length and frame_seq_length target_shape = ( @@ -640,6 +634,13 @@ def __call__( f"num_latent_frames-1 ({num_latent_frames - 1}) must be divisible by num_frame_per_block ({num_frame_per_block})" num_blocks = (num_latent_frames - 1) // num_frame_per_block + # Initialize ComfyUI progress bar after calculating num_blocks + if comfyui_progressbar: + from comfy.utils import ProgressBar + # Total steps = num_blocks * num_inference_steps + 1 (for latent preparation) + pbar = ProgressBar(num_blocks * num_inference_steps + 1) + pbar.update(1) + # Self-Forcing causal state (reset per call) current_start_frame = start_frame_index cache_start_frame = 0 @@ -688,86 +689,91 @@ def __call__( if hasattr(self.scheduler, 'model_outputs'): self.scheduler.model_outputs = [] - for step_idx, t in enumerate(timesteps): - - # Per-frame timesteps for causal generation - timestep = torch.ones([batch_size, current_num_frames], device=device, dtype=torch.long) * t - - if do_classifier_free_guidance: - # Conditional path - with torch.cuda.amp.autocast(dtype=weight_dtype): - flow_pred_cond = self.transformer( - x=noisy_input, - context=prompt_embeds, - t=timestep, - seq_len=seq_len, - kv_cache=self.kv_cache_pos, - crossattn_cache=self.crossattn_cache_pos, - current_start=current_start_frame * frame_seq_length, - cache_start=None, - ) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for step_idx, t in enumerate(timesteps): + # Per-frame timesteps for causal generation + timestep = torch.ones([batch_size, current_num_frames], device=device, dtype=torch.long) * t + + if comfyui_progressbar: + pbar.update(1) + + if do_classifier_free_guidance: + # Conditional path + with torch.cuda.amp.autocast(dtype=weight_dtype): + flow_pred_cond = self.transformer( + x=noisy_input, + context=prompt_embeds, + t=timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_pos, + crossattn_cache=self.crossattn_cache_pos, + current_start=current_start_frame * frame_seq_length, + 
cache_start=None, + ) + + # Unconditional path + with torch.cuda.amp.autocast(dtype=weight_dtype): + flow_pred_uncond = self.transformer( + x=noisy_input, + context=negative_prompt_embeds, + t=timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_neg, + crossattn_cache=self.crossattn_cache_neg, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + ) + + # CFG guidance + # Transformer output shape check + if flow_pred_cond.dim() == 5: + # Already [B, C, F, H, W] + flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_cond - flow_pred_uncond) + elif flow_pred_cond.dim() == 4: + # [F, C, H, W], need to add batch dim + flow_pred_cond = flow_pred_cond.unsqueeze(0).permute(0, 2, 1, 3, 4) + flow_pred_uncond = flow_pred_uncond.unsqueeze(0).permute(0, 2, 1, 3, 4) + flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_cond - flow_pred_uncond) + else: + raise ValueError(f"Unexpected flow_pred_cond dim: {flow_pred_cond.dim()}, shape: {flow_pred_cond.shape}") + else: + # Forward pass with KV cache + with torch.cuda.amp.autocast(dtype=weight_dtype): + flow_pred = self.transformer( + x=noisy_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + kv_cache=self.kv_cache_pos, + crossattn_cache=self.crossattn_cache_pos, + current_start=current_start_frame * frame_seq_length, + cache_start=None, + ) + + # Transformer output shape check + if flow_pred.dim() == 4: + # [F, C, H, W], need to add batch dim and permute + flow_pred = flow_pred.unsqueeze(0).permute(0, 2, 1, 3, 4) + # If already 5D [B, C, F, H, W], no need to permute - # Unconditional path - with torch.cuda.amp.autocast(dtype=weight_dtype): - flow_pred_uncond = self.transformer( - x=noisy_input, - context=negative_prompt_embeds, - t=timestep, - seq_len=seq_len, - kv_cache=self.kv_cache_neg, - crossattn_cache=self.crossattn_cache_neg, - current_start=current_start_frame * frame_seq_length, - cache_start=None, - ) + # Get current sigma for x0 conversion + sigma_t = 
self.scheduler.sigmas[step_idx] - # CFG guidance - # Transformer output shape check - if flow_pred_cond.dim() == 5: - # Already [B, C, F, H, W] - flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_cond - flow_pred_uncond) - elif flow_pred_cond.dim() == 4: - # [F, C, H, W], need to add batch dim - flow_pred_cond = flow_pred_cond.unsqueeze(0).permute(0, 2, 1, 3, 4) - flow_pred_uncond = flow_pred_uncond.unsqueeze(0).permute(0, 2, 1, 3, 4) - flow_pred = flow_pred_uncond + guidance_scale * (flow_pred_cond - flow_pred_uncond) + # Convert to x0: x0 = x_t - sigma_t * flow_pred (matches original wan_wrapper.py line 192) + denoised_pred = noisy_input - sigma_t * flow_pred # [B*F, C, H, W] + + if step_idx < len(timesteps) - 1: + # Not the last step: add noise for next timestep + next_t = timesteps[step_idx + 1] + + # Add noise using flow matching formula: x_{t+1} = (1-sigma_{t+1}) * x0 + sigma_{t+1} * noise + next_sigma = self.scheduler.sigmas[step_idx + 1] + local_noise = torch.randn(denoised_pred.shape, device=denoised_pred.device, dtype=denoised_pred.dtype, generator=generator) + noisy_input = (1 - next_sigma) * denoised_pred + next_sigma * local_noise else: - raise ValueError(f"Unexpected flow_pred_cond dim: {flow_pred_cond.dim()}, shape: {flow_pred_cond.shape}") - else: - # Forward pass with KV cache - with torch.cuda.amp.autocast(dtype=weight_dtype): - flow_pred = self.transformer( - x=noisy_input, - context=in_prompt_embeds, - t=timestep, - seq_len=seq_len, - kv_cache=self.kv_cache_pos, - crossattn_cache=self.crossattn_cache_pos, - current_start=current_start_frame * frame_seq_length, - cache_start=None, - ) + noisy_input = denoised_pred - # Transformer output shape check - if flow_pred.dim() == 4: - # [F, C, H, W], need to add batch dim and permute - flow_pred = flow_pred.unsqueeze(0).permute(0, 2, 1, 3, 4) - # If already 5D [B, C, F, H, W], no need to permute - - # Get current sigma for x0 conversion - sigma_t = self.scheduler.sigmas[step_idx] - - # Convert 
to x0: x0 = x_t - sigma_t * flow_pred (matches original wan_wrapper.py line 192) - denoised_pred = noisy_input - sigma_t * flow_pred # [B*F, C, H, W] - - if step_idx < len(timesteps) - 1: - # Not the last step: add noise for next timestep - next_t = timesteps[step_idx + 1] - - # Add noise using flow matching formula: x_{t+1} = (1-sigma_{t+1}) * x0 + sigma_{t+1} * noise - next_sigma = self.scheduler.sigmas[step_idx + 1] - local_noise = torch.randn(denoised_pred.shape, device=denoised_pred.device, dtype=denoised_pred.dtype, generator=generator) - noisy_input = (1 - next_sigma) * denoised_pred + next_sigma * local_noise - else: - noisy_input = denoised_pred + progress_bar.update() # Update output with denoised block output[:, :, cache_start_frame:cache_start_frame + current_num_frames] = denoised_pred @@ -823,9 +829,6 @@ def __call__( callback_outputs = callback_on_step_end(self, block_idx, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - if comfyui_progressbar: - pbar.update(1) - # 9. 
Decode output if output_type == "pil": From b72917690f32b056580201acc687f5217f5933af Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 16:59:28 +0800 Subject: [PATCH 11/16] Update ernie_image --- examples/ernie_image/predict_t2i.py | 210 +++ scripts/ernie_image/README_TRAIN.md | 525 ++++++ scripts/ernie_image/README_TRAIN_zh-CN.md | 525 ++++++ scripts/ernie_image/train.py | 1595 ++++++++++++++++++ scripts/ernie_image/train.sh | 35 + videox_fun/dist/__init__.py | 1 + videox_fun/dist/ernie_image_xfuser.py | 98 ++ videox_fun/dist/ltx2_xfuser.py | 18 - videox_fun/models/__init__.py | 1 + videox_fun/models/ernie_image_transformer.py | 500 ++++++ videox_fun/pipeline/__init__.py | 1 + videox_fun/pipeline/pipeline_ernie_image.py | 414 +++++ 12 files changed, 3905 insertions(+), 18 deletions(-) create mode 100644 examples/ernie_image/predict_t2i.py create mode 100644 scripts/ernie_image/README_TRAIN.md create mode 100644 scripts/ernie_image/README_TRAIN_zh-CN.md create mode 100644 scripts/ernie_image/train.py create mode 100644 scripts/ernie_image/train.sh create mode 100644 videox_fun/dist/ernie_image_xfuser.py create mode 100644 videox_fun/models/ernie_image_transformer.py create mode 100644 videox_fun/pipeline/pipeline_ernie_image.py diff --git a/examples/ernie_image/predict_t2i.py b/examples/ernie_image/predict_t2i.py new file mode 100644 index 00000000..8da5745f --- /dev/null +++ b/examples/ernie_image/predict_t2i.py @@ -0,0 +1,210 @@ +import os +import sys + +import torch +from diffusers import FlowMatchEulerDiscreteScheduler + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.dist import set_multi_gpus_devices, shard_model +from 
videox_fun.models import (AutoencoderKLFlux2, AutoTokenizer, + ErnieImageTransformer2DModel, Mistral3Model) +from videox_fun.pipeline import ErnieImagePipeline +from videox_fun.utils import (register_auto_device_hook, + safe_enable_group_offload) +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper) +from videox_fun.utils.lora_utils import merge_lora, unmerge_lora + +# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, model_group_offload, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. +# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_group_offload transfers internal layer groups between CPU/CUDA, +# balancing memory efficiency and speed between full-module and leaf-level offloading methods. +# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "model_cpu_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. 
+ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = False +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload. +compile_dit = False + +# model path +model_name = "models/Diffusion_Transformer/ERNIE-Image" + +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" + +# Load pretrained model if need +transformer_path = None +vae_path = None +lora_path = None + +# Other params +sample_size = [1728, 992] + +# Use torch.float16 if GPU does not support torch.bfloat16 +# Some graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +negative_prompt = "低分辨率,低画质,肢体畸形,手指畸形,画面过饱和,蜡像感,人脸无细节,过度光滑,画面具有AI感。构图混乱。文字模糊,扭曲。" +guidance_scale = 4.5 +seed = 43 +num_inference_steps = 40 +lora_weight = 0.55 +save_path = "samples/ernie-image-t2i" + +device = set_multi_gpus_devices(ulysses_degree, ring_degree) + +transformer = ErnieImageTransformer2DModel.from_pretrained( + model_name, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +).to(weight_dtype) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get Vae +vae = AutoencoderKLFlux2.from_pretrained( + model_name, + 
subfolder="vae" +).to(weight_dtype) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get tokenizer and text_encoder +tokenizer = AutoTokenizer.from_pretrained( + model_name, subfolder="tokenizer" +) +text_encoder = Mistral3Model.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=weight_dtype +) + +# Get Scheduler +Chosen_Scheduler = scheduler_dict = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +scheduler = Chosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) + +pipeline = ErnieImagePipeline( + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, +) + +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(transformer.layers)) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + +if compile_dit: + for i in range(len(pipeline.transformer.layers)): + pipeline.transformer.layers[i] = torch.compile(pipeline.transformer.layers[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_group_offload": + register_auto_device_hook(pipeline.transformer) + safe_enable_group_offload(pipeline, onload_device=device, offload_device="cpu", 
offload_type="leaf_level", use_stream=True) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +with torch.no_grad(): + sample = pipeline( + prompt, + negative_prompt = negative_prompt, + height = sample_size[0], + width = sample_size[1], + generator = generator, + guidance_scale = guidance_scale, + num_inference_steps = num_inference_steps, + ).images + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + video_path = os.path.join(save_path, prefix + ".png") + image = sample[0] + image.save(video_path) + +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/scripts/ernie_image/README_TRAIN.md b/scripts/ernie_image/README_TRAIN.md new file mode 100644 index 00000000..9703a20c --- /dev/null +++ b/scripts/ernie_image/README_TRAIN.md @@ -0,0 +1,525 @@ +# ERNIE-Image Full Parameter Training 
Guide + +This document provides a complete workflow for full parameter training of ERNIE-Image Diffusion Transformer, including environment configuration, data preparation, distributed training, and inference testing. + +--- + +## Table of Contents +- [1. Environment Configuration](#1-environment-configuration) +- [2. Data Preparation](#2-data-preparation) + - [2.1 Quick Test Dataset](#21-quick-test-dataset) + - [2.2 Dataset Structure](#22-dataset-structure) + - [2.3 metadata.json Format](#23-metadatajson-format) + - [2.4 Relative vs Absolute Path Usage](#24-relative-vs-absolute-path-usage) +- [3. Full Parameter Training](#3-full-parameter-training) + - [3.1 Download Pretrained Model](#31-download-pretrained-model) + - [3.2 Quick Start (DeepSpeed-Zero-2)](#32-quick-start-deepspeed-zero-2) + - [3.3 Common Training Parameters](#33-common-training-parameters) + - [3.4 Training Validation](#34-training-validation) + - [3.5 Training with FSDP](#35-training-with-fsdp) + - [3.6 Training Without DeepSpeed or FSDP](#36-training-without-deepspeed-or-fsdp) + - [3.7 Multi-Machine Distributed Training](#37-multi-machine-distributed-training) +- [4. Inference Testing](#4-inference-testing) + - [4.1 Inference Parameters](#41-inference-parameters) + - [4.2 Single GPU Inference](#42-single-gpu-inference) + - [4.3 Multi-GPU Parallel Inference](#43-multi-gpu-parallel-inference) +- [5. Additional Resources](#5-additional-resources) + +--- + +## 1. 
Environment Configuration + +**Method 1: Using requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**Method 2: Manual Dependency Installation** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**Method 3: Using Docker** + +When using Docker, please ensure that the GPU driver and CUDA environment are correctly installed on your machine, then execute the following commands: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined --shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 2. Data Preparation + +### 2.1 Quick Test Dataset + +We provide a test dataset containing several training samples. + +```bash +# Download official demo dataset +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 Dataset Structure + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... 
+│ └── 📄 metadata.json +``` + +### 2.3 metadata.json Format + +**Relative Path Format** (example): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**Absolute Path Format**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**Key Fields Description**: +- `file_path`: Image path (relative or absolute) +- `text`: Image description (English prompt) +- `width` / `height`: Image dimensions (**recommended** to provide for bucket training; if not provided, they will be automatically read during training, which may slow down training when data is stored on slow systems like OSS) + - You can use `scripts/process_json_add_width_and_height.py` to add width and height fields to JSON files without these fields, supporting both images and videos + - Usage: `python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json` + +### 2.4 Relative vs Absolute Path Usage + +**Relative Paths**: + +If your data uses relative paths, configure the training script as follows: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**Absolute Paths**: + +If your data uses absolute paths, configure the training script as follows: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **Recommendation**: If the dataset is small and stored locally, use relative paths. If the dataset is stored on external storage (e.g., NAS, OSS) or shared across multiple machines, use absolute paths. 
+ +--- + +## 3. Full Parameter Training + +### 3.1 Download Pretrained Model + +```bash +# Create model directory +mkdir -p models/Diffusion_Transformer + +# Download ERNIE-Image official weights +modelscope download --model PaddlePaddle/ERNIE-Image --local_dir models/Diffusion_Transformer/ERNIE-Image +``` + +### 3.2 Quick Start (DeepSpeed-Zero-2) + +If you have downloaded the data as per **2.1 Quick Test Dataset** and the weights as per **3.1 Download Pretrained Model**, you can directly copy and run the quick start command. + +DeepSpeed-Zero-2 and FSDP are recommended for training. Here we use DeepSpeed-Zero-2 as an example. + +The difference between DeepSpeed-Zero-2 and FSDP lies in whether the model weights are sharded. **If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```bash +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.3 Common Training Parameters + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|-----|------|-------| +| `--pretrained_model_name_or_path` | Path to pretrained model | `models/Diffusion_Transformer/ERNIE-Image` | +| `--train_data_dir` | Training data directory | `datasets/internal_datasets/` | +| `--train_data_meta` | Training data metadata file | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | Samples per batch | 1 | +| `--image_sample_size` | Maximum training resolution, auto bucketing | 1328 | +| `--gradient_accumulation_steps` | Gradient accumulation steps (equivalent to larger batch) | 1 | +| `--dataloader_num_workers` | DataLoader subprocesses | 8 | +| `--num_train_epochs` | Number of training epochs | 100 | +| `--checkpointing_steps` | Save checkpoint every N steps | 50 | +| `--learning_rate` | Initial learning rate | 2e-05 | +| `--lr_scheduler` | Learning rate scheduler | `constant_with_warmup` | +| `--lr_warmup_steps` | Learning rate warmup steps | 100 | +| `--seed` | Random seed | 42 | 
+| `--output_dir` | Output directory | `output_dir_ernie_image` | +| `--gradient_checkpointing` | Enable activation checkpointing | - | +| `--mixed_precision` | Mixed precision: `fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW weight decay | 3e-2 | +| `--adam_epsilon` | AdamW epsilon value | 1e-10 | +| `--vae_mini_batch` | Mini-batch size for VAE encoding | 1 | +| `--max_grad_norm` | Gradient clipping threshold | 0.05 | +| `--enable_bucket` | Enable bucket training: trains entire images grouped by resolution without center cropping | - | +| `--random_hw_adapt` | Auto-scale images to random size in range `[512, image_sample_size]` | - | +| `--resume_from_checkpoint` | Resume training from checkpoint path, use `"latest"` to auto-select latest | None | +| `--uniform_sampling` | Uniform timestep sampling | - | +| `--trainable_modules` | Trainable modules (`"."` means all modules) | `"."` | +| `--validation_steps` | Execute validation every N steps | 100 | +| `--validation_epochs` | Execute validation every N epochs | 100 | +| `--validation_prompts` | Prompts used during validation | `"a young girl..."` | + + +### 3.4 Training Validation + +You can configure validation parameters to periodically generate test images during training, allowing you to monitor training progress and model quality. + +**Validation Parameters**: + +| Parameter | Description | Recommended Value | +|-----------|-------------|-------------------| +| `--validation_steps` | Execute validation every N steps | 100 | +| `--validation_epochs` | Execute validation every N epochs | 100 | +| `--validation_prompts` | Prompt for validation image generation. 
Use multiple space-separated prompt strings | Space-separated prompt strings | + +**Example**: + +```bash + --validation_steps=100 \ + --validation_epochs=100 \ + --validation_prompts="a young girl with flowing long hair, wearing a white halter dress" +``` + +**Notes**: +- Validation images will be saved to the `output_dir` directory +- For multi-prompt validation, use: `--validation_prompts "prompt1" "prompt2" "prompt3"` + +### 3.5 Training with FSDP + +**If VRAM is insufficient when using multiple GPUs with DeepSpeed-Zero-2**, you can switch to FSDP. + +```sh +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap ErnieImageSharedAdaLNBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +### 3.6 Training Without DeepSpeed or FSDP + +**This approach is not recommended as it lacks VRAM-saving backends and may easily cause out-of-memory errors**. This is provided for reference only. + +```sh +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +### 3.7 Multi-Machine Distributed Training + +**Suitable for**: Ultra-large-scale datasets, faster training speed + +#### 3.7.1 Environment Configuration + +Assuming 2 machines with 8 GPUs each: + +**Machine 0 (Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master machine IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # Total number of machines +export NUM_PROCESS=16 # Total processes = machines × 8 +export RANK=0 # Current machine rank (0 or 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +**Machine 1 (Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Same as Master +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # Note this is 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# Use the same accelerate launch command as Machine 0 +``` + +#### 3.7.2 Multi-Machine Training Notes + +- **Network Requirements**: + - RDMA/InfiniBand recommended (high performance) + - Without RDMA, add environment variables: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **Data Synchronization**: All machines must be able to access the same data paths (NFS/shared storage) + +## 4. Inference Testing + +### 4.1 Inference Parameters + +**Key Parameter Descriptions**: + +| Parameter | Description | Example Value | +|------|------|-------| +| `GPU_memory_mode` | GPU memory mode, see table below for options | `model_cpu_offload` | +| `ulysses_degree` | Head dimension parallelization degree, 1 for single GPU | 1 | +| `ring_degree` | Sequence dimension parallelization degree, 1 for single GPU | 1 | +| `fsdp_dit` | Use FSDP for Transformer in multi-GPU inference to save VRAM | `False` | +| `fsdp_text_encoder` | Use FSDP for text encoder in multi-GPU inference | `False` | +| `compile_dit` | Compile Transformer to accelerate inference (effective at fixed resolution) | `False` | +| `model_name` | Model path | `models/Diffusion_Transformer/ERNIE-Image` | +| `sampler_name` | Sampler type: `Flow`, `Flow_Unipc`, `Flow_DPM++` | `Flow` | +| `transformer_path` | Path to trained Transformer weights | `None` | +| `vae_path` | Path to trained VAE weights | `None` | +| `lora_path` | LoRA weights path | `None` | 
+| `sample_size` | Generated image resolution `[height, width]` | `[1728, 992]` | +| `weight_dtype` | Model weight precision, use `torch.float16` for GPUs without bf16 support | `torch.bfloat16` | +| `prompt` | Positive prompt describing the content to generate | `"1girl, black_hair..."` | +| `negative_prompt` | Negative prompt for content to avoid | `"低分辨率,低画质..."` | +| `guidance_scale` | Guidance strength | 4.5 | +| `seed` | Random seed for reproducibility | 43 | +| `num_inference_steps` | Inference steps | 40 | +| `lora_weight` | LoRA weight strength | 0.55 | +| `save_path` | Generated image save path | `samples/ernie-image-t2i` | + +**GPU Memory Mode Description**: + +| Mode | Description | VRAM Usage | +|------|------|---------| +| `model_full_load` | Load entire model to GPU | Highest | +| `model_full_load_and_qfloat8` | Full load + FP8 quantization | High | +| `model_cpu_offload` | Offload model to CPU after use | Medium | +| `model_cpu_offload_and_qfloat8` | CPU offload + FP8 quantization | Medium-Low | +| `model_group_offload` | Layer group offload between CPU/CUDA | Low | +| `sequential_cpu_offload` | Offload each layer individually (slowest) | Lowest | + +### 4.2 Single GPU Inference + +Run single GPU inference with: + +```bash +python examples/ernie_image/predict_t2i.py +``` + +Edit `examples/ernie_image/predict_t2i.py` according to your needs. For first-time inference, focus on these parameters. For other parameters, see the Inference Parameters section above. + +```python +# Choose based on your GPU VRAM +GPU_memory_mode = "model_cpu_offload" +# Your actual model path +model_name = "models/Diffusion_Transformer/ERNIE-Image" +# Trained weights path, e.g. 
"output_dir_ernie_image/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# Write based on content to generate +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... +``` + +### 4.3 Multi-GPU Parallel Inference + +**Suitable for**: High-resolution generation, accelerated inference + +#### Install Parallel Inference Dependencies + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### Configure Parallel Strategy + +Edit `examples/ernie_image/predict_t2i.py`: + +```python +# Ensure ulysses_degree × ring_degree = number of GPUs +# For example, using 2 GPUs: +ulysses_degree = 2 # Head dimension parallelization +ring_degree = 1 # Sequence dimension parallelization +``` + +**Configuration Principles**: +- `ulysses_degree` must evenly divide the model's number of heads +- `ring_degree` splits on sequence dimension, affecting communication overhead; avoid using it when heads can be divided + +**Example Configurations**: + +| GPU Count | ulysses_degree | ring_degree | Description | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | Single GPU | +| 4 | 4 | 1 | Head parallelization | +| 8 | 8 | 1 | Head parallelization | +| 8 | 4 | 2 | Hybrid parallelization | + +#### Run Multi-GPU Inference + +```bash +torchrun --nproc-per-node=2 examples/ernie_image/predict_t2i.py +``` + +## 5. 
Additional Resources + +- **Official GitHub**: https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/ernie_image/README_TRAIN_zh-CN.md b/scripts/ernie_image/README_TRAIN_zh-CN.md new file mode 100644 index 00000000..db12d30b --- /dev/null +++ b/scripts/ernie_image/README_TRAIN_zh-CN.md @@ -0,0 +1,525 @@ +# ERNIE-Image 全量参数训练指南 + +本文档提供 ERNIE-Image Diffusion Transformer 全量参数训练的完整流程,包括环境配置、数据准备、分布式训练和推理测试。 + +--- + +## 目录 +- [一、环境配置](#一环境配置) +- [二、数据准备](#二数据准备) + - [2.1 快速测试数据集](#21-快速测试数据集) + - [2.2 数据集结构](#22-数据集结构) + - [2.3 metadata.json 格式](#23-metadatajson-格式) + - [2.4 相对路径与绝对路径使用方案](#24-相对路径与绝对路径使用方案) +- [三、全量参数训练](#三全量参数训练) + - [3.1 下载预训练模型](#31-下载预训练模型) + - [3.2 快速开始(DeepSpeed-Zero-2)](#32-快速开始deepspeed-zero-2) + - [3.3 训练常用参数解析](#33-训练常用参数解析) + - [3.4 训练验证](#34-训练验证) + - [3.5 使用 FSDP 训练](#35-使用-fsdp-训练) + - [3.6 其他后端](#36-其他后端) + - [3.7 多机分布式训练](#37-多机分布式训练) +- [四、推理测试](#四推理测试) + - [4.1 推理参数解析](#41-推理参数解析) + - [4.2 单卡推理](#42-单卡推理) + - [4.3 多卡并行推理](#43-多卡并行推理) +- [五、更多资源](#五更多资源) + +--- + +## 一、环境配置 + +**方式 1:使用requirements.txt** + +```bash +pip install -r requirements.txt +``` + +**方式 2:手动安装依赖** + +```bash +pip install Pillow einops safetensors timm tomesd librosa "torch>=2.1.2" torchdiffeq torchsde decord datasets numpy scikit-image +pip install omegaconf SentencePiece imageio[ffmpeg] imageio[pyav] tensorboard beautifulsoup4 ftfy func_timeout onnxruntime +pip install "peft>=0.17.0" "accelerate>=0.25.0" "gradio>=3.41.2" "diffusers>=0.30.1" "transformers>=4.46.2" +pip install yunchang xfuser modelscope openpyxl deepspeed==0.17.0 numpy==1.26.4 +pip uninstall opencv-python opencv-contrib-python opencv-python-headless -y +pip install opencv-python-headless +``` + +**方式 3:使用docker** + +使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后依次执行以下命令: + +``` +# pull image +docker pull mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun + +# enter image +docker run -it -p 7860:7860 --network host --gpus all --security-opt seccomp:unconfined 
--shm-size 200g mybigpai-public-registry.cn-beijing.cr.aliyuncs.com/easycv/torch_cuda:cogvideox_fun +``` + +--- + +## 二、数据准备 + +### 2.1 快速测试数据集 + +我们提供了一个测试的数据集,其中包含若干训练数据。 + +```bash +# 下载官方示例数据集 +modelscope download --dataset PAI/X-Fun-Images-Demo --local_dir ./datasets/X-Fun-Images-Demo +``` + +### 2.2 数据集结构 + +``` +📦 datasets/ +├── 📂 my_dataset/ +│ ├── 📂 train/ +│ │ ├── 📄 image001.jpg +│ │ ├── 📄 image002.png +│ │ └── 📄 ... +│ └── 📄 metadata.json +``` + +### 2.3 metadata.json 格式 + +**相对路径格式**(示例格式): +```json +[ + { + "file_path": "train/image001.jpg", + "text": "A beautiful sunset over the ocean, golden hour lighting", + "width": 1024, + "height": 1024 + }, + { + "file_path": "train/image002.png", + "text": "Portrait of a young woman, studio lighting, high quality", + "width": 1328, + "height": 1328 + } +] +``` + +**绝对路径格式**: +```json +[ + { + "file_path": "/mnt/data/images/sunset.jpg", + "text": "A beautiful sunset over the ocean", + "width": 1024, + "height": 1024 + } +] +``` + +**关键字段说明**: +- `file_path`:图片路径(相对或绝对路径) +- `text`:图片描述(英文提示词) +- `width` / `height`:图片宽高(**最好提供**,用于分桶训练,如果不提供则自动在训练时读取,当数据存储在如oss这样的速度较慢的系统上时,可能会影响训练速度)。 + - 可以使用`scripts/process_json_add_width_and_height.py`文件对无width与height字段的json进行提取,支持处理图片与视频。 + - 使用方案为`python scripts/process_json_add_width_and_height.py --input_file datasets/X-Fun-Images-Demo/metadata.json --output_file datasets/X-Fun-Images-Demo/metadata_add_width_height.json`。 + +### 2.4 相对路径与绝对路径使用方案 + +**相对路径**: + +如果数据的路径为相对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +``` + +**绝对路径**: + +如果数据的路径为绝对路径,则在训练脚本中设置: + +```bash +export DATASET_NAME="" +export DATASET_META_NAME="/mnt/data/metadata.json" +``` + +> 💡 **建议**:如果数据集较小且存储在本地,推荐使用相对路径;如果数据集存储在外部存储(如 NAS、OSS)或多个机器共享存储,推荐使用绝对路径。 + +--- + +## 三、全量参数训练 + +### 3.1 下载预训练模型 + +```bash +# 创建模型目录 +mkdir -p models/Diffusion_Transformer + +# 下载 ERNIE-Image 官方权重 +modelscope download 
--model PaddlePaddle/ERNIE-Image --local_dir models/Diffusion_Transformer/ERNIE-Image +``` + +### 3.2 快速开始(DeepSpeed-Zero-2) + +如果按照 **2.1 快速测试数据集下载数据** 与 **3.1 下载预训练模型下载权重**后,直接复制快速开始的启动指令进行启动。 + +推荐使用DeepSpeed-Zero-2与FSDP方案进行训练。这里使用DeepSpeed-Zero-2为例配置shell文件。 + +本文中DeepSpeed-Zero-2与FSDP的差别在于是否对模型权重进行分片,**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```bash +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +### 3.3 训练常用参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|-----|------|-------| +| `--pretrained_model_name_or_path` | 预训练模型路径 | `models/Diffusion_Transformer/ERNIE-Image` | +| `--train_data_dir` | 训练数据目录 | `datasets/internal_datasets/` | +| `--train_data_meta` | 训练数据元文件 | `datasets/internal_datasets/metadata.json` | +| `--train_batch_size` | 每批次样本数 | 1 | +| `--image_sample_size` | 最大训练分辨率,代码会自动分桶 | 1328 | +| `--gradient_accumulation_steps` | 梯度累积步数(等效增大 batch) | 1 | +| `--dataloader_num_workers` | DataLoader 子进程数 | 8 | +| `--num_train_epochs` | 训练 epoch 数 | 100 | +| `--checkpointing_steps` | 每 N 步保存 checkpoint | 50 | +| `--learning_rate` | 初始学习率 | 2e-05 | +| `--lr_scheduler` | 学习率调度器 | `constant_with_warmup` | +| `--lr_warmup_steps` | 学习率预热步数 | 100 | +| `--seed` | 随机种子 | 42 | +| `--output_dir` | 输出目录 | `output_dir_ernie_image` | +| `--gradient_checkpointing` | 激活重计算 | - | +| `--mixed_precision` | 混合精度:`fp16/bf16` | `bf16` | +| `--adam_weight_decay` | AdamW 权重衰减 | 3e-2 | +| `--adam_epsilon` | AdamW epsilon 值 | 1e-10 | +| `--vae_mini_batch` | VAE 编码时的迷你批次大小 | 1 | +| `--max_grad_norm` | 梯度裁剪阈值 | 0.05 | +| `--enable_bucket` | 启用分桶训练,不裁剪图片,按分辨率分组训练整个图像 | - | +| `--random_hw_adapt` | 自动缩放图片到 `[512, image_sample_size]` 范围内的随机尺寸 | - | +| `--resume_from_checkpoint` | 恢复训练路径,使用 `"latest"` 自动选择最新 checkpoint | None | +| `--uniform_sampling` | 均匀采样 timestep | - | +| `--trainable_modules` | 可训练模块(`"."` 表示所有模块) | `"."` | +| `--validation_steps` | 每 N 步执行一次验证 | 100 | +| `--validation_epochs` | 每 N 个epoch执行一次验证 | 100 | +| `--validation_prompts` | 验证图像生成的提示词 | `"一位年轻女子..."` | + + +### 3.4 训练验证 + +你可以配置验证参数,在训练过程中定期生成测试图像,以便监控训练进度和模型质量。 + +**验证参数说明**: + +| 参数 | 说明 | 推荐值 | +|------|------|--------| +| `--validation_steps` | 每 N 步执行一次验证 | 100 | +| `--validation_epochs` | 每 N 个epoch执行一次验证 | 100 | +| `--validation_prompts` | 验证图像生成的提示词,可用空格分隔多个提示词 | 多个空格分隔的提示词 | + +**示例**: + +```bash + --validation_steps=100 \ + --validation_epochs=100 \ + 
--validation_prompts="一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。" +``` + +**注意事项**: +- 验证图像会保存到 `output_dir` 目录中 +- 多提示词验证格式:`--validation_prompts "prompt1" "prompt2" "prompt3"` + +### 3.5 使用 FSDP 训练 + +**如果使用多卡且使用DeepSpeed-Zero-2的情况下显存不足**,可以切换使用FSDP进行训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap ErnieImageSharedAdaLNBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." 
+``` + +### 3.6 不使用 DeepSpeed 与 FSDP 训练 + +**该方案并不被推荐,因为没有显存节约后端,容易造成显存不足**。这里仅提供训练Shell用于参考训练。 + +```sh +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. +# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +### 3.7 多机分布式训练 + +**适合场景**:超大规模数据集、需要更快的训练速度 + +#### 3.7.1 环境配置 + +假设有 2 台机器,每台 8 张 GPU: + +**机器 0(Master)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # Master 机器 IP +export MASTER_PORT=10086 +export WORLD_SIZE=2 # 机器总数 +export NUM_PROCESS=16 # 总进程数 = 机器数 × 8 +export RANK=0 # 当前机器 rank(0 或 1) +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/ernie_image/train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=100 \ + --checkpointing_steps=50 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant_with_warmup" \ + --lr_warmup_steps=100 \ + --seed=42 \ + --output_dir="output_dir_ernie_image" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --uniform_sampling \ + --trainable_modules "." +``` + +**机器 1(Worker)**: +```bash +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/X-Fun-Images-Demo/" +export DATASET_META_NAME="datasets/X-Fun-Images-Demo/metadata_add_width_height.json" +export MASTER_ADDR="192.168.1.100" # 与 Master 相同 +export MASTER_PORT=10086 +export WORLD_SIZE=2 +export NUM_PROCESS=16 +export RANK=1 # 注意这里是 1 +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +# 使用与机器 0 相同的 accelerate launch 命令 +``` + +#### 3.7.2 多机训练注意事项 + +- **网络要求**: + - 推荐 RDMA/InfiniBand(高性能) + - 无 RDMA 时添加环境变量: + ```bash + export NCCL_IB_DISABLE=1 + export NCCL_P2P_DISABLE=1 + ``` + +- **数据同步**:所有机器必须能够访问相同的数据路径(NFS/共享存储) + +## 四、推理测试 + +### 4.1 推理参数解析 + +**关键参数说明**: + +| 参数 | 说明 | 示例值 | +|------|------|-------| +| `GPU_memory_mode` | 显存管理模式,可选值见下表 | `model_cpu_offload` | +| `ulysses_degree` | Head 维度并行度,单卡时为 1 | 1 | +| `ring_degree` | Sequence 维度并行度,单卡时为 1 | 1 | +| `fsdp_dit` | 多卡推理时对 Transformer 使用 FSDP 节省显存 | `False` | +| `fsdp_text_encoder` | 多卡推理时对文本编码器使用 FSDP | `False` | +| `compile_dit` | 编译 Transformer 加速推理(固定分辨率下有效) | `False` | +| `model_name` | 模型路径 | `models/Diffusion_Transformer/ERNIE-Image` | +| `sampler_name` | 采样器类型:`Flow`、`Flow_Unipc`、`Flow_DPM++` | `Flow` | +| `transformer_path` | 加载训练好的 Transformer 权重路径 | `None` | +| `vae_path` | 加载训练好的 VAE 权重路径 | `None` | +| `lora_path` | LoRA 权重路径 | `None` | +| `sample_size` | 生成图像分辨率 `[高度, 宽度]` | `[1728, 992]` | +| `weight_dtype` | 模型权重精度,不支持 bf16 的显卡使用 `torch.float16` | `torch.bfloat16` | +| `prompt` | 正向提示词,描述生成内容 | `"1girl, black_hair..."` | +| `negative_prompt` | 负向提示词,避免生成的内容 | `"低分辨率,低画质..."` | +| `guidance_scale` | 引导强度 | 4.5 | +| `seed` | 随机种子,用于复现结果 | 43 | +| `num_inference_steps` | 推理步数 | 40 | +| `lora_weight` | LoRA 权重强度 | 0.55 | +| `save_path` | 生成图像保存路径 | `samples/ernie-image-t2i` | + +**显存管理模式说明**: + +| 模式 | 说明 | 显存占用 | +|------|------|---------| +| `model_full_load` | 整个模型加载到 GPU | 最高 | +| `model_full_load_and_qfloat8` | 全量加载 + FP8 量化 | 高 | +| `model_cpu_offload` | 使用后将模型卸载到 CPU | 中等 | +| `model_cpu_offload_and_qfloat8` | CPU 卸载 + FP8 量化 | 中低 | +| `model_group_offload` | 层组在 CPU/CUDA 间切换 | 低 | +| `sequential_cpu_offload` | 逐层卸载(速度最慢) | 最低 | + +### 4.2 单卡推理 + +单卡推理运行如下命令: + +```bash +python examples/ernie_image/predict_t2i.py +``` + +根据需求修改编辑 
`examples/ernie_image/predict_t2i.py`,初次推理重点关注如下参数,如果对其他参数感兴趣,请查看上方的推理参数解析。 + +```python +# 根据显卡显存选择 +GPU_memory_mode = "model_cpu_offload" +# 根据实际模型路径 +model_name = "models/Diffusion_Transformer/ERNIE-Image" +# 训练好的权重路径,如 "output_dir_ernie_image/checkpoint-xxx/diffusion_pytorch_model.safetensors" +transformer_path = None +# 根据生成内容编写 +prompt = "1girl, black_hair, brown_eyes, earrings, freckles, grey_background, jewelry, lips, long_hair, looking_at_viewer, nose, piercing, realistic, red_lips, solo, upper_body" +# ... +``` + +### 4.3 多卡并行推理 + +**适合场景**:高分辨率生成、加速推理 + +#### 安装并行推理依赖 + +```bash +pip install xfuser==0.4.2 yunchang==0.6.2 +``` + +#### 配置并行策略 + +编辑 `examples/ernie_image/predict_t2i.py`: + +```python +# 确保 ulysses_degree × ring_degree = GPU 数量 +# 例如使用 2 张 GPU: +ulysses_degree = 2 # Head 维度并行 +ring_degree = 1 # Sequence 维度并行 +``` + +**配置原则**: +- `ulysses_degree` 必须能整除模型的head数。 +- `ring_degree` 会在sequence上切分,影响通信开销,在head数能切分的时候尽量不用。 + +**示例配置**: + +| GPU 数量 | ulysses_degree | ring_degree | 说明 | +|---------|---------------|-------------|------| +| 1 | 1 | 1 | 单卡 | +| 4 | 4 | 1 | Head 并行 | +| 8 | 8 | 1 | Head 并行 | +| 8 | 4 | 2 | 混合并行 | + +#### 运行多卡推理 + +```bash +torchrun --nproc-per-node=2 examples/ernie_image/predict_t2i.py +``` + +## 五、更多资源 + +- **官方 GitHub**:https://github.com/aigc-apps/VideoX-Fun diff --git a/scripts/ernie_image/train.py b/scripts/ernie_image/train.py new file mode 100644 index 00000000..be74e0f6 --- /dev/null +++ b/scripts/ernie_image/train.py @@ -0,0 +1,1595 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import argparse
import gc
import logging
import math
import os
import pickle
import random
import shutil
import sys
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, Tuple,
                    Union)

import accelerate
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (EMAModel,
                                      compute_density_for_timestep_sampling,
                                      compute_loss_weighting_for_sd3)
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from einops import rearrange
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from torch.utils.data import RandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers.utils import ContextManagers

import datasets

current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
for project_root in project_roots:
    sys.path.insert(0, project_root) if project_root not in sys.path else None

from videox_fun.data import (ASPECT_RATIO_512, ASPECT_RATIO_RANDOM_CROP_512,
                             ASPECT_RATIO_RANDOM_CROP_PROB,
                             AspectRatioBatchImageVideoSampler,
                             ImageVideoDataset, ImageVideoSampler,
                             RandomSampler, get_closest_ratio, get_random_mask)
from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKLFlux2, AutoTokenizer,
                               ErnieImageTransformer2DModel,
                               Mistral3Model)
from videox_fun.pipeline import ErnieImagePipeline
from videox_fun.utils.discrete_sampler import DiscreteSampling
from videox_fun.utils.utils import get_image_to_video_latent, save_videos_grid

if is_wandb_available():
    import wandb


def filter_kwargs(cls, kwargs):
    """Return the subset of ``kwargs`` accepted by ``cls.__init__``.

    Args:
        cls: Class whose ``__init__`` signature is inspected.
        kwargs: Mapping of candidate keyword arguments.

    Returns:
        A new dict holding only the entries whose key is a named parameter of
        ``cls.__init__`` (``self`` / ``cls`` excluded).
    """
    import inspect

    accepted = set(inspect.signature(cls.__init__).parameters) - {'self', 'cls'}
    return {key: value for key, value in kwargs.items() if key in accepted}


def linear_decay(initial_value, final_value, total_steps, current_step):
    """Linearly interpolate from ``initial_value`` to ``final_value``.

    Args:
        initial_value: Value returned at step 0 (and for negative steps).
        final_value: Value returned once ``current_step >= total_steps``.
        total_steps: Number of steps over which the interpolation runs.
        current_step: Current step; negative values are clamped to 0.

    Returns:
        The interpolated value for ``current_step``.
    """
    # A non-positive horizon means the schedule is already finished; returning
    # early also avoids dividing by zero below.
    if total_steps <= 0 or current_step >= total_steps:
        return final_value
    current_step = max(0, current_step)
    step_size = (final_value - initial_value) / total_steps
    return initial_value + step_size * current_step


def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None):
    """Sample integer timesteps by squashing standard-normal draws.

    A standard-normal sample ``u`` is passed through the logistic function,
    rescaled to ``[low, high]``, truncated to ``int32`` and clipped to
    ``[low, high - 1]``.

    Args:
        low: Lower bound of the timestep range (inclusive).
        high: Upper bound of the timestep range (exclusive after clipping).
        shape: Shape of the returned tensor.
        device: Device on which the samples are drawn.
        generator: Optional ``torch.Generator`` for reproducible sampling.

    Returns:
        An ``int32`` tensor of the requested shape.
    """
    u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator)
    t = 1 / (1 + torch.exp(-u)) * (high - low) + low
    return torch.clip(t.to(torch.int32), low, high - 1)


def _patchify_latents(latents):
    """2x2 patchify: [B, C, H, W] -> [B, 4*C, H/2, W/2].

    Each 2x2 spatial patch is folded into the channel axis; output channel
    ``c * 4 + dh * 2 + dw`` holds input channel ``c`` at offset ``(dh, dw)``
    inside the patch.
    """
    batch_size, num_channels_latents, height, width = latents.shape
    # ``reshape`` (not ``view``) so non-contiguous inputs are accepted too.
    latents = latents.reshape(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 1, 3, 5, 2, 4)
    latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
    return latents

def encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    text_encoder = None,
    tokenizer = None,
    num_images_per_prompt: int = 1,
    dtype: Optional[torch.dtype] = None,
    text_in_dim: Optional[int] = None,
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
    """Encode text prompts to padded embeddings (Ernie-Image style).

    Every prompt is tokenized without padding, run through ``text_encoder``
    under ``torch.no_grad()``, and represented by the second-to-last hidden
    state. The per-prompt sequences are then right-padded with zeros into a
    single batch tensor.

    Args:
        prompt: One prompt or a list of prompts.
        device: Device for the token ids and the returned tensors.
        text_encoder: Callable accepting ``input_ids`` and
            ``output_hidden_states`` and returning an object with a
            ``hidden_states`` sequence.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list;
            its ``bos_token_id`` is used when a prompt tokenizes to nothing.
        num_images_per_prompt: How many times each prompt embedding is
            repeated (consecutively) in the batch.
        dtype: Dtype of the returned embeddings.
        text_in_dim: Feature width of the padded tensor. When ``None`` it is
            taken from the encoder's hidden states (or ``0`` if there is
            nothing to encode).

    Returns:
        ``(text_bth, text_lens)``: the zero-padded embeddings of shape
        ``[B, Tmax, text_in_dim]`` and a ``long`` tensor with each row's true
        token count.
    """
    if isinstance(prompt, str):
        prompt = [prompt]

    text_hiddens = []
    for p in prompt:
        ids = tokenizer(
            p,
            add_special_tokens=True,
            truncation=True,
            padding=False,
        )["input_ids"]

        # An empty tokenization still needs one token to feed the encoder.
        if len(ids) == 0:
            ids = [tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 0]

        input_ids = torch.tensor([ids], device=device)
        with torch.no_grad():
            outputs = text_encoder(
                input_ids=input_ids,
                output_hidden_states=True,
            )
        # Use second to last hidden state
        hidden = outputs.hidden_states[-2][0]  # [T, H]

        # Repeat for num_images_per_prompt
        text_hiddens.extend([hidden] * num_images_per_prompt)

    # Nothing to pad: return empty tensors with a well-defined feature width.
    if not text_hiddens:
        feature_dim = 0 if text_in_dim is None else text_in_dim
        text_bth = torch.zeros((0, 0, feature_dim), device=device, dtype=dtype)
        text_lens = torch.zeros((0,), device=device, dtype=torch.long)
        return text_bth, text_lens

    normalized = [
        (th.squeeze(1) if th.dim() == 3 else th).to(device=device, dtype=dtype)
        for th in text_hiddens
    ]
    # Lengths are kept as Python ints so the max needs no device round-trip.
    lengths = [t.shape[0] for t in normalized]
    # Fall back to the encoder's own width when the caller did not supply one.
    feature_dim = normalized[0].shape[-1] if text_in_dim is None else text_in_dim

    text_lens = torch.tensor(lengths, device=device, dtype=torch.long)
    text_bth = torch.zeros((len(normalized), max(lengths), feature_dim), device=device, dtype=dtype)
    for i, t in enumerate(normalized):
        text_bth[i, : t.shape[0], :] = t

    return text_bth, text_lens

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, text_encoder, tokenizer, transformer3d, args, accelerator, weight_dtype, global_step): + try: + is_deepspeed = type(transformer3d).__name__ == 'DeepSpeedEngine' + if is_deepspeed: + origin_config = transformer3d.config + transformer3d.config = accelerator.unwrap_model(transformer3d).config + with torch.no_grad(), torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + logger.info("Running validation... ") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + pipeline = ErnieImagePipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=accelerator.unwrap_model(transformer3d) if type(transformer3d).__name__ == 'DistributedDataParallel' else transformer3d, + scheduler=scheduler, + ) + pipeline = pipeline.to(accelerator.device) + + if args.seed is None: + generator = None + else: + rank_seed = args.seed + accelerator.process_index + generator = torch.Generator(device=accelerator.device).manual_seed(rank_seed) + logger.info(f"Rank {accelerator.process_index} using seed: {rank_seed}") + + for i in range(len(args.validation_prompts)): + sample = pipeline( + args.validation_prompts[i], + negative_prompt = "bad detailed", + height = args.image_sample_size, + width = args.image_sample_size, + generator = generator, + guidance_scale = 0 if "turbo" in args.pretrained_model_name_or_path.lower() else 4.5, + num_inference_steps = 8 if "turbo" in args.pretrained_model_name_or_path.lower() else 25, + ).images + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + image = sample[0].save( + os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.jpg" + ) + ) + + del pipeline + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + 
vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if is_deepspeed: + transformer3d.config = origin_config + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + print(f"Eval error on rank {accelerator.process_index} with info {e}") + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. " + ), + ) + parser.add_argument( + "--train_data_meta", + type=str, + default=None, + help=( + "A csv containing the training data. " + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." 
+ ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_model_info", action="store_true", help="Whether or not to report more info about model (such as norm, grad)." + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. 
Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." + ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--image_sample_size", + type=int, + default=512, + help="Sample size of the image.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." 
+ ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + + parser.add_argument( + '--trainable_modules', + nargs='+', + help='Enter a list of trainable modules' + ) + parser.add_argument( + '--trainable_modules_low_learning_rate', + nargs='+', + default=[], + help='Enter a list of trainable modules with lower learning rate' + ) + parser.add_argument( + '--tokenizer_max_length', + type=int, + default=512, + help='Max length of tokenizer' + ) + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--prompt_template_encode", + type=str, + default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + help=( + 'The prompt template for text encoder.' + ), + ) + parser.add_argument( + "--prompt_template_encode_start_idx", + type=int, + default=34, + help=( + 'The start idx for prompt template.' + ), + ) + parser.add_argument( + "--abnormal_norm_clip_start", + type=int, + default=1000, + help=( + 'When do we start doing additional processing on abnormal gradients. ' + ), + ) + parser.add_argument( + "--initial_grad_norm_ratio", + type=int, + default=5, + help=( + 'The initial gradient is relative to the multiple of the max_grad_norm. 
' + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." 
+ ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. 
def deepspeed_zero_init_disabled_context_manager():
    """Build the context-manager list that turns DeepSpeed ``zero.Init`` off.

    Returns:
        A one-element list holding the DeepSpeed plugin's
        ``zero3_init_context_manager(enable=False)`` when the accelerate state
        is initialized and exposes a DeepSpeed plugin; an empty list otherwise.
    """
    # Without an initialized accelerate state there is no plugin to query.
    if not accelerate.state.is_initialized():
        return []

    plugin = AcceleratorState().deepspeed_plugin
    if plugin is None:
        return []

    return [plugin.zero3_init_context_manager(enable=False)]
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = Mistral3Model.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype + ) + text_encoder = text_encoder.eval() + # Get Vae + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae" + ).to(weight_dtype) + vae.eval() + latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(accelerator.device, weight_dtype) + latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(accelerator.device, weight_dtype) + + # Get Transformer + transformer3d = ErnieImageTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=weight_dtype, + ).to(weight_dtype) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, 
strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + # A good trainable modules is showed below now. + # For 3D Patch: trainable_modules = ['ff.net', 'pos_embed', 'attn2', 'proj_out', 'timepositionalencoding', 'h_position', 'w_position'] + # For 2D Patch: trainable_modules = ['ff.net', 'attn2', 'timepositionalencoding', 'h_position', 'w_position'] + transformer3d.train() + if accelerator.is_main_process: + accelerator.print( + f"Trainable modules '{args.trainable_modules}'." + ) + for name, param in transformer3d.named_parameters(): + for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + param.requires_grad = True + break + + # Create EMA for the transformer3d. + if args.use_ema: + if zero_stage == 3: + raise NotImplementedError("FSDP does not support EMA.") + + ema_transformer3d = ErnieImageTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=weight_dtype, + ).to(weight_dtype) + + ema_transformer3d = EMAModel(ema_transformer3d.parameters(), model_cls=ErnieImageTransformer2DModel, model_config=ema_transformer3d.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + if fsdp_stage != 0 or zero_stage == 3: + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + + safetensor_save_path = os.path.join(output_dir, f"diffusion_pytorch_model.safetensors") + accelerate_state_dict = {k: v.to(dtype=weight_dtype) for k, v in accelerate_state_dict.items()} + save_file(accelerate_state_dict, safetensor_save_path, metadata={"format": "pt"}) + 
+ with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + else: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_transformer3d.save_pretrained(os.path.join(output_dir, "transformer_ema")) + + models[0].save_pretrained(os.path.join(output_dir, "transformer")) + if not args.use_deepspeed: + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + if args.use_ema: + ema_path = os.path.join(input_dir, "transformer_ema") + _, ema_kwargs = ErnieImageTransformer2DModel.load_config(ema_path, return_unused_kwargs=True) + load_model = ErnieImageTransformer2DModel.from_pretrained( + input_dir, subfolder="transformer_ema", + ) + load_model = EMAModel(load_model.parameters(), model_cls=ErnieImageTransformer2DModel, model_config=load_model.config) + load_model.load_state_dict(ema_kwargs) + + ema_transformer3d.load_state_dict(load_model.state_dict()) + ema_transformer3d.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ErnieImageTransformer2DModel.from_pretrained( + input_dir, subfolder="transformer" + ) + 
model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except Exception: + raise ImportError( + "Please install came_pytorch to use CAME. 
def worker_init_fn(_seed):
    """Create a DataLoader ``worker_init_fn`` that seeds NumPy and ``random``.

    Args:
        _seed: Base integer seed for the current process. It is multiplied by
            256 before the worker id is added (presumably so that workers of
            neighbouring base seeds do not share a seed -- confirm).

    Returns:
        A callable taking ``worker_id`` that seeds ``np.random`` and ``random``
        with the same per-worker value.
    """
    _seed = _seed * 256

    def _worker_init_fn(worker_id):
        # np.random.seed only accepts values in [0, 2**32 - 1]; wrap the seed
        # so that large base seeds (>= 2**24) no longer crash worker start-up.
        # Seeds that were already valid are left unchanged by the modulo.
        worker_seed = (_seed + worker_id) % (2 ** 32)
        print(f"worker_init_fn with {worker_seed}")
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    return _worker_init_fn


def collate_fn(examples):
    """Resize, crop and normalize a list of samples into one training batch.

    The target size is chosen once per batch, in this priority order:
    ``args.fix_sample_size``, a randomly drawn crop ratio
    (``args.random_ratio_crop``), or the aspect-ratio bucket closest to the
    first sample. All target sizes are floored to a multiple of 16.

    Args:
        examples: Samples from the dataset. Each one is read through the keys
            ``"pixel_values"`` (a NumPy array unpacked as ``(f, h, w, c)``)
            and ``"text"``.

    Returns:
        A dict with ``"pixel_values"`` (stacked tensor of the transformed
        samples) and ``"text"`` (list of prompts). When
        ``args.enable_text_encoder_in_dataloader`` is set it also contains
        ``"prompt_embeds"`` and ``"text_lens"`` produced by ``encode_prompt``.
    """
    def get_random_downsample_ratio(sample_size, image_ratio=None,
                                    all_choices=False, rng=None):
        """Draw a downsample ratio; the first candidate gets probability 0.9."""
        def _create_special_list(length):
            # The first entry gets 0.9 and the others share the remaining 0.1.
            # ``length`` is always >= 1 because every candidate list below
            # contains at least the ratio 1.
            if length == 1:
                return [1.0]
            first_element = 0.90
            remaining_sum = 1.0 - first_element
            other_elements_value = remaining_sum / (length - 1)
            return [first_element] + [other_elements_value] * (length - 1)

        # ``None`` replaces the former mutable ``[]`` default.
        extra_ratios = list(image_ratio) if image_ratio is not None else []

        if sample_size >= 1536:
            number_list = [1, 1.25, 1.5, 2, 2.5, 3] + extra_ratios
        elif sample_size >= 1024:
            number_list = [1, 1.25, 1.5, 2] + extra_ratios
        elif sample_size >= 768:
            number_list = [1, 1.25, 1.5] + extra_ratios
        elif sample_size >= 512:
            number_list = [1] + extra_ratios
        else:
            number_list = [1]

        if all_choices:
            return number_list

        number_list_prob = np.array(_create_special_list(len(number_list)))
        if rng is None:
            return np.random.choice(number_list, p=number_list_prob)
        return rng.choice(number_list, p=number_list_prob)

    def _floor_to_multiple_of_16(size):
        # Floor every side length to a multiple of 16.
        return [int(x / 16) * 16 for x in size]

    new_examples = {"pixel_values": [], "text": []}

    # The spatial size of the first sample drives the bucket choice for the
    # whole batch.
    _, h, w, _ = np.shape(examples[0]["pixel_values"])

    random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size)

    # Pick the batch-wide target size once, outside the per-sample loop.
    fix_sample_size = None
    random_sample_size = None
    closest_size = None
    if args.fix_sample_size is not None:
        fix_sample_size = _floor_to_multiple_of_16(args.fix_sample_size)
    elif args.random_ratio_crop:
        aspect_ratio_random_crop_sample_size = {
            key: [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]]
            for key in ASPECT_RATIO_RANDOM_CROP_512.keys()
        }
        crop_keys = list(aspect_ratio_random_crop_sample_size.keys())
        if rng is None:
            crop_key = np.random.choice(crop_keys, p=ASPECT_RATIO_RANDOM_CROP_PROB)
        else:
            crop_key = rng.choice(crop_keys, p=ASPECT_RATIO_RANDOM_CROP_PROB)
        random_sample_size = _floor_to_multiple_of_16(aspect_ratio_random_crop_sample_size[crop_key])
    else:
        aspect_ratio_sample_size = {
            key: [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]]
            for key in ASPECT_RATIO_512.keys()
        }
        closest_size, _ = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size)
        closest_size = _floor_to_multiple_of_16(closest_size)

    for example in examples:
        # (f, h, w, c) array -> (f, c, h, w) tensor, scaled to 0~1.
        pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        if fix_sample_size is not None:
            resize = transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR)
            crop_size = fix_sample_size
        elif random_sample_size is not None:
            # Resize so that the sample covers the target, then center-crop.
            _, _, src_h, src_w = pixel_values.size()
            th, tw = random_sample_size
            if th / tw > src_h / src_w:
                nh = int(th)
                nw = int(src_w / src_h * nh)
            else:
                nw = int(tw)
                nh = int(src_h / src_w * nw)
            resize = transforms.Resize([nh, nw])
            crop_size = random_sample_size
        else:
            # Scale by the larger of the two side ratios so the crop fits.
            if closest_size[0] / h > closest_size[1] / w:
                resize_size = closest_size[0], int(w * closest_size[0] / h)
            else:
                resize_size = int(h * closest_size[1] / w), closest_size[1]
            resize = transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR)
            crop_size = closest_size

        transform = transforms.Compose([
            resize,
            transforms.CenterCrop(crop_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        new_examples["pixel_values"].append(transform(pixel_values))
        new_examples["text"].append(example["text"])

    new_examples["pixel_values"] = torch.stack(new_examples["pixel_values"])

    # Encode prompts when enable_text_encoder_in_dataloader=True
    if args.enable_text_encoder_in_dataloader:
        # Use the prompts collected above; the previous code read an undefined
        # ``batch`` variable here and raised NameError.
        prompt_embeds, text_lens = encode_prompt(
            new_examples["text"], device="cpu",
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            dtype=weight_dtype,
            text_in_dim=transformer3d.config.text_in_dim,
        )

        new_examples['prompt_embeds'] = prompt_embeds
        new_examples['text_lens'] = text_lens

    return new_examples
False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + else: + # DataLoaders creation: + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. 
def unwrap_model(model):
    """Return the module underneath the accelerator and ``torch.compile`` wrappers.

    Args:
        model: A model that may have been wrapped by ``accelerator.prepare``
            and/or compiled.

    Returns:
        The result of ``accelerator.unwrap_model``; if that is a compiled
        module, its ``_orig_mod`` attribute instead.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. 
Get first_epoch = {first_epoch}.") + + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup inpaint vae computation + vae_stream_1 = torch.cuda.Stream() + vae_stream_2 = torch.cuda.Stream() + else: + vae_stream_1 = None + vae_stream_2 = None + + # Calculate the index we need】 + idx_sampling = DiscreteSampling(args.train_sampling_steps, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + # Data batch sanity check + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): + pixel_value = pixel_value[None, ...] 
def _batch_encode_vae(pixel_values):
    """Encode pixels with the VAE in mini-batches and concatenate the results.

    Encoding in chunks of ``args.vae_mini_batch`` is quicker when the batch
    grows.

    Args:
        pixel_values: Input tensor; dimension 1 is squeezed away when it has
            size 1 (presumably the single-frame axis -- confirm with caller).

    Returns:
        The sampled encodings of all chunks, concatenated along dimension 0.
    """
    frames = pixel_values.squeeze(1)
    chunk_size = args.vae_mini_batch
    # ``vae.encode(...)[0]`` is sampled right away; each chunk is encoded in
    # order so the concatenation keeps the original sample order.
    encoded_chunks = [
        vae.encode(frames[start : start + chunk_size])[0].sample()
        for start in range(0, frames.shape[0], chunk_size)
    ]
    return torch.cat(encoded_chunks, dim=0)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma of every timestep, shaped for broadcasting.

    Args:
        timesteps: Tensor of timesteps; each one must match exactly one entry
            of ``noise_scheduler.timesteps`` (``.item()`` needs a single hit).
        n_dim: Number of dimensions of the returned tensor.
        dtype: Dtype the scheduler sigmas are cast to.

    Returns:
        The selected sigmas, flattened and then padded with trailing
        singleton dimensions until the tensor has ``n_dim`` dimensions.
    """
    device = accelerator.device
    sigma_table = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule = noise_scheduler.timesteps.to(device)

    # Position of each requested timestep inside the scheduler's schedule.
    step_indices = []
    for t in timesteps.to(device):
        step_indices.append((schedule == t).nonzero().item())

    selected = sigma_table[step_indices].flatten()
    # Append singleton axes so the result broadcasts against the latents.
    for _ in range(n_dim - selected.dim()):
        selected = selected.unsqueeze(-1)
    return selected
matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + # Add noise + target = noise - latents + + # Predict the noise residual + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred = transformer3d( + hidden_states=noisy_latents, + timestep=timesteps, + text_bth=prompt_embeds, + text_lens=text_lens, + return_dict=False, + )[0] + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + loss = custom_mse_loss(noise_pred.float(), target.float(), weighting.float()) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + if not args.use_deepspeed and not args.use_fsdp: + trainable_params_grads = [p.grad for p in trainable_params if p.grad is not None] + trainable_params_total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in trainable_params_grads]), 2) + max_grad_norm = linear_decay(args.max_grad_norm * args.initial_grad_norm_ratio, args.max_grad_norm, args.abnormal_norm_clip_start, global_step) + if trainable_params_total_norm / max_grad_norm > 5 and global_step > args.abnormal_norm_clip_start: + actual_max_grad_norm = max_grad_norm / min((trainable_params_total_norm / max_grad_norm), 10) + else: + actual_max_grad_norm = max_grad_norm + else: + actual_max_grad_norm = args.max_grad_norm + + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + if trainable_params_total_norm > 1 and global_step > args.abnormal_norm_clip_start: + for name, param in transformer3d.named_parameters(): + if param.requires_grad: + writer.add_scalar(f'gradients/before_clip_norm/{name}', param.grad.norm(), global_step=global_step) + + norm_sum = accelerator.clip_grad_norm_(trainable_params, actual_max_grad_norm) + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + writer.add_scalar(f'gradients/norm_sum', norm_sum, global_step=global_step) + writer.add_scalar(f'gradients/actual_max_grad_norm', actual_max_grad_norm, global_step=global_step) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + if args.use_ema: + ema_transformer3d.step(transformer3d.parameters()) + progress_bar.update(1) + global_step += 1 + 
accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. 
+ ema_transformer3d.restore(transformer3d.parameters()) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. + ema_transformer3d.restore(transformer3d.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/ernie_image/train.sh b/scripts/ernie_image/train.sh new file mode 100644 index 00000000..2c7b2d5a --- /dev/null +++ b/scripts/ernie_image/train.sh @@ -0,0 +1,35 @@ +export MODEL_NAME="models/Diffusion_Transformer/ERNIE-Image" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
# NCCL_DEBUG must be exported: a bare `NCCL_DEBUG=INFO` only sets a shell-local
# variable, which the `accelerate launch` child processes would not inherit.
export NCCL_DEBUG=INFO

# Launch FSDP training of the ERNIE-Image transformer in bf16.
accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP \
  --fsdp_transformer_layer_cls_to_wrap ErnieImageSharedAdaLNBlock --fsdp_sharding_strategy "FULL_SHARD" \
  --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False \
  scripts/ernie_image/train.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATASET_NAME \
  --train_data_meta=$DATASET_META_NAME \
  --train_batch_size=1 \
  --image_sample_size=1328 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=100 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir_ernie_image" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --enable_bucket \
  --uniform_sampling \
  --trainable_modules "."
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention

from .fuser import xFuserLongContextAttention


class ErnieImageMultiGPUsAttnProcessor:
    """
    Attention processor for Ernie-Image multi-GPU (sequence parallel) inference.

    It performs the same projection, QK-normalization and rotary-embedding steps
    as the single-stream processor, then delegates the attention computation to
    `xFuserLongContextAttention`.

    Note: `attention_mask` is accepted for signature compatibility but is not
    forwarded to `xFuserLongContextAttention`; the distributed call receives
    only query, key and value.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ErnieImageMultiGPUsAttnProcessor requires PyTorch 2.0. "
                "To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    @staticmethod
    def _apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        """Rotate the first `freqs_cis.shape[-1]` channels of `x_in`; pass the rest through."""
        rot_dim = freqs_cis.shape[-1]
        rotated_part = x_in[..., :rot_dim]
        untouched_part = x_in[..., rot_dim:]
        cos_ = torch.cos(freqs_cis).to(rotated_part.dtype)
        sin_ = torch.sin(freqs_cis).to(rotated_part.dtype)
        # Non-interleaved rotate_half: [-x2, x1]
        first_half, second_half = rotated_part.chunk(2, dim=-1)
        half_rotated = torch.cat((-second_half, first_half), dim=-1)
        return torch.cat((rotated_part * cos_ + half_rotated * sin_, untouched_part), dim=-1)

    @staticmethod
    def _to_half(tensor: torch.Tensor) -> torch.Tensor:
        """Return `tensor` unchanged if it is fp16/bf16, otherwise cast it to bfloat16."""
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor
        return tensor.to(torch.bfloat16)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # QKV projections, each reshaped to [batch, seq_len, heads, head_dim]
        num_heads = attn.heads
        query, key, value = (
            projection(hidden_states).unflatten(-1, (num_heads, -1))
            for projection in (attn.to_q, attn.to_k, attn.to_v)
        )

        # Optional QK normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Rotary positional embeddings, same rotate_half logic as
        # ErnieImageSingleStreamAttnProcessor (rotary_interleaved=False)
        if freqs_cis is not None:
            query = self._apply_rotary_emb(query, freqs_cis)
            key = self._apply_rotary_emb(key, freqs_cis)

        # The query dtype is the reference dtype for the key and for the output.
        out_dtype = query.dtype
        key = key.to(out_dtype)

        # Distributed attention; inputs that are not fp16/bf16 are cast to bf16 first.
        attn_out = xFuserLongContextAttention()(
            None,
            self._to_half(query),
            self._to_half(key),
            self._to_half(value),
            dropout_p=0.0,
            causal=False,
        )

        # Merge heads back, restore the reference dtype and apply the output projection.
        attn_out = attn_out.flatten(2, 3).to(out_dtype)
        return attn.to_out[0](attn_out)
-""" - from typing import Tuple import os diff --git a/videox_fun/models/__init__.py b/videox_fun/models/__init__.py index 419388b1..66970bf2 100755 --- a/videox_fun/models/__init__.py +++ b/videox_fun/models/__init__.py @@ -29,6 +29,7 @@ from .cogvideox_transformer3d import CogVideoXTransformer3DModel from .cogvideox_vae import AutoencoderKLCogVideoX +from .ernie_image_transformer import ErnieImageTransformer2DModel from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel from .flashhead_audio_encoder import FlashHeadAudioEncoder diff --git a/videox_fun/models/ernie_image_transformer.py b/videox_fun/models/ernie_image_transformer.py new file mode 100644 index 00000000..32960e4f --- /dev/null +++ b/videox_fun/models/ernie_image_transformer.py @@ -0,0 +1,500 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Ernie-Image Transformer2DModel for HuggingFace Diffusers. 
import inspect
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import RMSNorm
from diffusers.utils import BaseOutput, logging

from .attention_utils import attention

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class ErnieImageTransformer2DModelOutput(BaseOutput):
    """Output of `ErnieImageTransformer2DModel.forward` when `return_dict=True`."""

    # Prediction reshaped to (B, out_channels, H, W) by the model's forward pass.
    sample: torch.Tensor


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Compute rotary angles for one positional axis.

    Args:
        pos: Positions along the axis; any leading shape `(..., n)`.
        dim: Number of channels assigned to this axis; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        float32 tensor of shape `(..., n, dim // 2)` holding `pos * theta ** (-2i / dim)`.
    """
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    return out.float()


class ErnieImageEmbedND3(nn.Module):
    """Build rotary angles from 3-component position ids (one `rope` call per axis)."""

    def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = list(axes_dim)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """Return angles for `ids` (last dim of size 3), each angle repeated twice in a row."""
        emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
        emb = emb.unsqueeze(2)  # [B, S, 1, head_dim//2]
        return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)  # [B, S, 1, head_dim]


class ErnieImagePatchEmbedDynamic(nn.Module):
    """Patchify an image-like tensor with a strided convolution and flatten it to tokens."""

    def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map `(B, C, H, W)` to `(B, H' * W', embed_dim)` where H', W' are the conv output sizes."""
        x = self.proj(x)
        batch_size, dim, height, width = x.shape
        return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()


class ErnieImageSingleStreamAttnProcessor:
    """Default (single-device) attention processor for `ErnieImageAttention`."""

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ErnieImageSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        freqs_cis: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # Apply Norms
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE: same rotate_half logic as Megatron _apply_rotary_pos_emb_bshd (rotary_interleaved=False)
        # x_in: [B, S, heads, head_dim], freqs_cis: [B, S, 1, head_dim] with angles [θ0,θ0,θ1,θ1,...]
        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            rot_dim = freqs_cis.shape[-1]
            x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
            cos_ = torch.cos(freqs_cis).to(x.dtype)
            sin_ = torch.sin(freqs_cis).to(x.dtype)
            # Non-interleaved rotate_half: [-x2, x1]
            x1, x2 = x.chunk(2, dim=-1)
            x_rotated = torch.cat((-x2, x1), dim=-1)
            return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Cast to correct dtype
        dtype = query.dtype
        query, key = query.to(dtype), key.to(dtype)

        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        # Compute joint attention
        hidden_states = attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
        )

        # Reshape back
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(dtype)
        output = attn.to_out[0](hidden_states)

        return output


class ErnieImageAttention(nn.Module):
    """
    Self-attention layer whose computation is delegated to a pluggable processor.

    Args:
        query_dim: Input feature size.
        heads: Number of heads (ignored when `out_dim` is given; then `out_dim // dim_head`).
        dim_head: Channels per head.
        dropout: Stored on the module; not applied by this class itself.
        bias: Whether the q/k/v projections use a bias.
        qk_norm: `"rms_norm"`, `"layer_norm"`, or `None` to disable QK normalization.
        added_proj_bias: Stored on the module; not used by this class itself.
        out_bias: Whether the output projection uses a bias.
        eps: Epsilon of the QK norm layers.
        out_dim: Optional inner/output size override.
        elementwise_affine: Whether the QK norm layers have learnable parameters.
        processor: Attention processor; defaults to `_default_processor_cls()`.

    Raises:
        ValueError: If `qk_norm` is not one of the supported values.
    """

    _default_processor_cls = ErnieImageSingleStreamAttnProcessor

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: str | None = "rms_norm",
        added_proj_bias: bool | None = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.added_proj_bias = added_proj_bias

        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # QK Norm
        if qk_norm is None:
            # ErnieImageSharedAdaLNBlock passes None when qk_layernorm=False; the
            # processors already skip normalization when norm_q / norm_k are None.
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        elif qk_norm == "rms_norm":
            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'.")

        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def set_processor(self, processor) -> None:
        """
        Set the attention processor to use.

        Args:
            processor: The attention processor to use.
        """
        # A module-type processor is registered as a submodule; drop that registration
        # before replacing it with a plain (non-module) processor.
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the configured processor.

        `encoder_hidden_states` is accepted but not used. Extra keyword arguments are
        forwarded only if the processor's `__call__` signature declares them; the rest
        are dropped with a warning.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k in kwargs if k not in attn_parameters]
        if unused_kwargs:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)


class ErnieImageFeedForward(nn.Module):
    """Gated feed-forward block: `linear_fc2(up_proj(x) * gelu(gate_proj(x)))`, all bias-free."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int):
        super().__init__()
        # Separate gate and up projections (matches converted weights)
        self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))


class ErnieImageSharedAdaLNBlock(nn.Module):
    """
    Transformer block (self-attention + gated MLP) modulated by externally supplied AdaLN terms.

    The shift/scale/gate tensors are passed in through `temb` rather than computed
    inside the block. Setting `qk_layernorm=False` disables QK normalization.
    """

    def __init__(
        self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True
    ):
        super().__init__()
        self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps)
        self.self_attention = ErnieImageAttention(
            query_dim=hidden_size,
            dim_head=hidden_size // num_heads,
            heads=num_heads,
            qk_norm="rms_norm" if qk_layernorm else None,
            eps=eps,
            bias=False,
            out_bias=False,
            processor=ErnieImageSingleStreamAttnProcessor(),
        )
        self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps)
        self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)

    def forward(
        self,
        x,
        rotary_pos_emb,
        temb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
    ):
        """
        Apply the block to a sequence-first input.

        Args:
            x: Hidden states, sequence-first `[S, B, H]`.
            rotary_pos_emb: Rotary angles forwarded to the attention layer.
            temb: `(shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp)`.
            attention_mask: Optional mask forwarded to the attention layer.

        Returns:
            Updated hidden states with the same layout as `x`.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
        residual = x
        x = self.adaLN_sa_ln(x)
        # Modulation and gating are computed in float32, then cast back.
        x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
        x_bsh = x.permute(1, 0, 2)  # [S, B, H] → [B, S, H] for diffusers Attention (batch-first)
        attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
        attn_out = attn_out.permute(1, 0, 2)  # [B, S, H] → [S, B, H]
        x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
        residual = x
        x = self.adaLN_mlp_ln(x)
        x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
        return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)


class ErnieImageAdaLNContinuous(nn.Module):
    """Final norm: parameter-free LayerNorm followed by a conditioning-derived scale and shift."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
        self.linear = nn.Linear(hidden_size, hidden_size * 2)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        scale, shift = self.linear(conditioning).chunk(2, dim=-1)
        x = self.norm(x)
        # Broadcast conditioning to sequence dimension
        x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
        return x


class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    Ernie-Image diffusion transformer.

    Image latents are patch-embedded into tokens, concatenated with projected text
    tokens (image tokens first), processed by a stack of `ErnieImageSharedAdaLNBlock`
    layers that all share one timestep-derived AdaLN modulation, and finally projected
    back to a `(B, out_channels, H, W)` tensor.
    """

    _supports_gradient_checkpointing = True
    _repeated_blocks = ["ErnieImageSharedAdaLNBlock"]

    @register_to_config
    def __init__(
        self,
        hidden_size: int = 3072,
        num_attention_heads: int = 24,
        num_layers: int = 24,
        ffn_hidden_size: int = 8192,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        text_in_dim: int = 2560,
        rope_theta: int = 256,
        rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
        eps: float = 1e-6,
        qk_layernorm: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.num_layers = num_layers
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.text_in_dim = text_in_dim

        self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
        # The text projection is skipped entirely when the text width already matches.
        self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
        self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
        self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
        # Zero-init the modulation and the output head.
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)
        self.layers = nn.ModuleList(
            [
                ErnieImageSharedAdaLNBlock(
                    hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm
                )
                for _ in range(num_layers)
            ]
        )
        self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
        self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
        nn.init.zeros_(self.final_linear.weight)
        nn.init.zeros_(self.final_linear.bias)
        self.gradient_checkpointing = False

        # Multi-GPU inference support
        self.sp_world_size = 1
        self.sp_world_rank = 0

    def enable_multi_gpus_inference(self):
        """Enable multi-GPU inference using sequence parallelism."""
        from ..dist import (ErnieImageMultiGPUsAttnProcessor,
                            get_sequence_parallel_rank,
                            get_sequence_parallel_world_size, get_sp_group)

        self.sp_world_size = get_sequence_parallel_world_size()
        self.sp_world_rank = get_sequence_parallel_rank()
        self.all_gather = get_sp_group().all_gather
        self.set_attn_processor(ErnieImageMultiGPUsAttnProcessor())

    def set_attn_processor(self, processor):
        """Set attention processor for all attention layers.

        Args:
            processor: The attention processor to use for all attention layers.
        """
        # The same processor instance is shared by every submodule exposing `set_processor`.
        for module in self.modules():
            if hasattr(module, "set_processor"):
                module.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        # encoder_hidden_states: List[torch.Tensor],
        text_bth: torch.Tensor,
        text_lens: torch.Tensor,
        return_dict: bool = True,
    ):
        """
        Run the transformer.

        Args:
            hidden_states: Latents of shape `(B, C, H, W)`.
            timestep: Timesteps fed to the sinusoidal projection (presumably shape `(B,)` — confirm with callers).
            text_bth: Padded text features, batch-first `(B, Tmax, text_in_dim)`.
            text_lens: Number of valid text tokens per sample; positions `>= text_lens` are masked out.
            return_dict: Return an `ErnieImageTransformer2DModelOutput` instead of a 1-tuple.

        Returns:
            `ErnieImageTransformer2DModelOutput` or `(output,)`, with `output` of shape
            `(B, out_channels, H, W)`.

        Raises:
            ValueError: In sequence-parallel mode, if the number of image tokens is not
                divisible by the sequence-parallel world size.
        """
        device, dtype = hidden_states.device, hidden_states.dtype
        B, C, H, W = hidden_states.shape
        p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
        N_img = Hp * Wp

        # Store original N_img for sequence parallel
        N_img_full = N_img

        img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
        # text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype)
        if self.text_proj is not None and text_bth.numel() > 0:
            text_bth = self.text_proj(text_bth)
        Tmax = text_bth.shape[1]
        text_sbh = text_bth.transpose(0, 1).contiguous()

        # Sequence parallel: chunk image tokens across GPUs
        if self.sp_world_size > 1:
            # torch.chunk yields ceil-sized chunks while the bookkeeping below uses floor
            # division, so an uneven split would desynchronize tokens and position ids.
            if N_img % self.sp_world_size != 0:
                raise ValueError(
                    f"Number of image tokens ({N_img}) must be divisible by the sequence parallel "
                    f"world size ({self.sp_world_size})."
                )
            N_img = N_img // self.sp_world_size
            img_sbh = torch.chunk(img_sbh, self.sp_world_size, dim=0)[self.sp_world_rank]

        # Image tokens first, then text tokens (each rank keeps the full text sequence).
        x = torch.cat([img_sbh, text_sbh], dim=0)
        S = x.shape[0]

        # Position IDs
        # Text tokens: (token index, 0, 0).
        text_ids = (
            torch.cat(
                [
                    torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
                    torch.zeros((B, Tmax, 2), device=device),
                ],
                dim=-1,
            )
            if Tmax > 0
            else torch.zeros((B, 0, 3), device=device)
        )
        grid_yx = torch.stack(
            torch.meshgrid(
                torch.arange(Hp, device=device, dtype=torch.float32),
                torch.arange(Wp, device=device, dtype=torch.float32),
                indexing="ij",
            ),
            dim=-1,
        ).reshape(-1, 2)

        # Image tokens: (text length of the sample, y, x).
        # Sequence parallel: use only the image_ids chunk for this GPU
        if self.sp_world_size > 1:
            chunk_start = self.sp_world_rank * N_img
            chunk_end = chunk_start + N_img
            image_ids = torch.cat(
                [text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
                 grid_yx[chunk_start:chunk_end].view(1, N_img, 2).expand(B, -1, -1)],
                dim=-1,
            )
        else:
            image_ids = torch.cat(
                [text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)],
                dim=-1,
            )

        rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))

        # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention
        valid_text = (
            torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
            if Tmax > 0
            else torch.zeros((B, 0), device=device, dtype=torch.bool)
        )
        attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[
            :, None, None, :
        ]

        # AdaLN
        sample = self.time_proj(timestep)
        sample = sample.to(dtype=dtype)
        c = self.time_embedding(sample)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
            t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
        ]
        # The modulation is shared by every layer, so build the list once.
        temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
        for layer in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x = self._gradient_checkpointing_func(
                    layer,
                    x,
                    rotary_pos_emb,
                    temb,
                    attention_mask,
                )
            else:
                x = layer(x, rotary_pos_emb, temb, attention_mask)
        x = self.final_norm(x, c).type_as(x)
        patches = self.final_linear(x)

        # Sequence parallel: gather image patches from all GPUs
        if self.sp_world_size > 1:
            # Only gather the image part (first N_img tokens)
            img_patches = patches[:N_img]
            img_patches = self.all_gather(img_patches, dim=0)
            # Reconstruct full patches: [full_img_tokens, text_tokens]
            patches = torch.cat([img_patches, patches[N_img:]], dim=0)
            # Use full N_img for output reshape
            N_img = N_img_full

        # Drop the text tokens and fold the image tokens back into (B, out_channels, H, W).
        output = (
            patches[:N_img].transpose(0, 1).contiguous()
            .view(B, Hp, Wp, p, p, self.out_channels)
            .permute(0, 5, 1, 3, 2, 4)
            .contiguous()
            .view(B, self.out_channels, H, W)
        )

        return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,)
+""" + +import json +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLFlux2, AutoTokenizer, + ErnieImageTransformer2DModel, Ministral3ForCausalLM, + Mistral3Model) + +if not hasattr(PIL.Image, "Image"): + raise ImportError("`ErnieImagePipeline` requires `PIL.Image`. Please install it with: `pip install Pillow`.") + + +@dataclass +class ErnieImagePipelineOutput(BaseOutput): + """ + Output class for Ernie-Image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline. + revised_prompts (`List[str]`, *optional*): + List of revised prompts after PE enhancement. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + revised_prompts: Optional[List[str]] = None + + +class ErnieImagePipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using ErnieImageTransformer2DModel. + + This pipeline uses: + - A custom DiT transformer model + - A Flux2-style VAE for encoding/decoding latents + - A text encoder (e.g., Mistral3Model) for text conditioning + - Flow Matching Euler Discrete Scheduler + """ + + model_cpu_offload_seq = "pe->text_encoder->transformer->vae" + # For SGLang fallback ... 
+ _optional_components = ["pe", "pe_tokenizer"] + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + transformer: ErnieImageTransformer2DModel, + vae: AutoencoderKLFlux2, + text_encoder: Mistral3Model, + tokenizer: AutoTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + pe: Optional[Ministral3ForCausalLM] = None, + pe_tokenizer: Optional[AutoTokenizer] = None, + ): + super().__init__() + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + pe=pe, + pe_tokenizer=pe_tokenizer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @torch.no_grad() + def _enhance_prompt_with_pe( + self, + prompt: str, + device: torch.device, + width: int = 1024, + height: int = 1024, + system_prompt: Optional[str] = None, + temperature: float = 0.6, + top_p: float = 0.95, + ) -> str: + """Use PE model to rewrite/enhance a short prompt via chat_template.""" + # Build user message as JSON carrying prompt text and target resolution + user_content = json.dumps( + {"prompt": prompt, "width": width, "height": height}, + ensure_ascii=False, + ) + messages = [] + if system_prompt is not None: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": user_content}) + + # apply_chat_template picks up the chat_template.jinja loaded with pe_tokenizer + input_text = self.pe_tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=False, # "Output:" is already in the user block + ) + inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device) + output_ids = self.pe.generate( + **inputs, + 
max_new_tokens=self.pe_tokenizer.model_max_length, + do_sample=temperature != 1.0 or top_p != 1.0, + temperature=temperature, + top_p=top_p, + pad_token_id=self.pe_tokenizer.pad_token_id, + eos_token_id=self.pe_tokenizer.eos_token_id, + ) + # Decode only newly generated tokens + generated_ids = output_ids[0][inputs["input_ids"].shape[1] :] + return self.pe_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + + @torch.no_grad() + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_images_per_prompt: int = 1, + ) -> List[torch.Tensor]: + """Encode text prompts to embeddings.""" + if isinstance(prompt, str): + prompt = [prompt] + + text_hiddens = [] + + for p in prompt: + ids = self.tokenizer( + p, + add_special_tokens=True, + truncation=True, + padding=False, + )["input_ids"] + + if len(ids) == 0: + if self.tokenizer.bos_token_id is not None: + ids = [self.tokenizer.bos_token_id] + else: + ids = [0] + + input_ids = torch.tensor([ids], device=device) + with torch.no_grad(): + outputs = self.text_encoder( + input_ids=input_ids, + output_hidden_states=True, + ) + # Use second to last hidden state (matches training) + hidden = outputs.hidden_states[-2][0] # [T, H] + + # Repeat for num_images_per_prompt + for _ in range(num_images_per_prompt): + text_hiddens.append(hidden) + + return text_hiddens + + @staticmethod + def _patchify_latents(latents: torch.Tensor) -> torch.Tensor: + """2x2 patchify: [B, 32, H, W] -> [B, 128, H/2, W/2]""" + b, c, h, w = latents.shape + latents = latents.view(b, c, h // 2, 2, w // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + return latents.reshape(b, c * 4, h // 2, w // 2) + + @staticmethod + def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor: + """Reverse patchify: [B, 128, H/2, W/2] -> [B, 32, H, W]""" + b, c, h, w = latents.shape + latents = latents.reshape(b, c // 4, 2, 2, h, w) + latents = latents.permute(0, 1, 4, 2, 5, 3) + return latents.reshape(b, c // 4, h * 2, w 
* 2) + + @staticmethod + def _pad_text(text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype, text_in_dim: int): + B = len(text_hiddens) + if B == 0: + return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros( + (0,), device=device, dtype=torch.long + ) + normalized = [ + th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens + ] + lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) + Tmax = int(lens.max().item()) + text_bth = torch.zeros((B, Tmax, text_in_dim), device=device, dtype=dtype) + for i, t in enumerate(normalized): + text_bth[i, : t.shape[0], :] = t + return text_bth, lens + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = "", + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: list[torch.FloatTensor] | None = None, + negative_prompt_embeds: list[torch.FloatTensor] | None = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + use_pe: bool = True, # Use the PE model to rewrite (enhance) the prompt by default + ): + """ + Generate images from text prompts. + + Args: + prompt: Text prompt(s) + negative_prompt: Negative prompt(s) for CFG. Default is "". + height: Image height in pixels (must be divisible by 16). Default: 1024. + width: Image width in pixels (must be divisible by 16). Default: 1024. + num_inference_steps: Number of denoising steps + guidance_scale: CFG scale (1.0 = no guidance). Default: 4.0. 
+ num_images_per_prompt: Number of images per prompt + generator: Random generator for reproducibility + latents: Pre-generated latents (optional) + prompt_embeds: Pre-computed text embeddings for positive prompts (optional). + If provided, `encode_prompt` is skipped for positive prompts. + negative_prompt_embeds: Pre-computed text embeddings for negative prompts (optional). + If provided, `encode_prompt` is skipped for negative prompts. + output_type: "pil" or "latent" + return_dict: Whether to return a dataclass + callback_on_step_end: Optional callback invoked at the end of each denoising step. + Called as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)` where `callback_kwargs` + contains the tensors listed in `callback_on_step_end_tensor_inputs`. The callback may return a dict to + override those tensors for subsequent steps. + callback_on_step_end_tensor_inputs: List of tensor names passed into the callback kwargs. + Must be a subset of `_callback_tensor_inputs` (default: `["latents"]`). + use_pe: Whether to use the PE model to enhance prompts before generation. + + Returns: + :class:`ErnieImagePipelineOutput` with `images` and `revised_prompts`. 
+ """ + device = self._execution_device + dtype = self.transformer.dtype + + self._guidance_scale = guidance_scale + + # Validate prompt / prompt_embeds + if prompt is None and prompt_embeds is None: + raise ValueError("Must provide either `prompt` or `prompt_embeds`.") + if prompt is not None and prompt_embeds is not None: + raise ValueError("Cannot provide both `prompt` and `prompt_embeds` at the same time.") + + # Validate dimensions + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}") + + # Handle prompts + if prompt is not None: + if isinstance(prompt, str): + prompt = [prompt] + + # [Phase 1] PE: enhance prompts + revised_prompts: Optional[List[str]] = None + if prompt is not None and use_pe and self.pe is not None and self.pe_tokenizer is not None: + prompt = [self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt] + revised_prompts = list(prompt) + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + total_batch_size = batch_size * num_images_per_prompt + + # Handle negative prompt + if negative_prompt is None: + negative_prompt = "" + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if len(negative_prompt) != batch_size: + raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})") + + # [Phase 2] Text encoding + if prompt_embeds is not None: + text_hiddens = [h for h in prompt_embeds for _ in range(num_images_per_prompt)] + else: + text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt) + + # CFG with negative prompt + if self.do_classifier_free_guidance: + if negative_prompt_embeds is not None: + uncond_text_hiddens = [h for h in negative_prompt_embeds for _ in range(num_images_per_prompt)] + else: + uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt) + + # 
Latent dimensions + latent_h = height // self.vae_scale_factor + latent_w = width // self.vae_scale_factor + latent_channels = self.transformer.config.in_channels # After patchify + + # Initialize latents + if latents is None: + latents = randn_tensor( + (total_batch_size, latent_channels, latent_h, latent_w), + generator=generator, + device=device, + dtype=dtype, + ) + + # Setup scheduler + sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1) + self.scheduler.set_timesteps(sigmas=sigmas[:-1], device=device) + + # Denoising loop + if self.do_classifier_free_guidance: + cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens) + else: + cfg_text_hiddens = text_hiddens + text_bth, text_lens = self._pad_text( + text_hiddens=cfg_text_hiddens, device=device, dtype=dtype, text_in_dim=self.transformer.config.text_in_dim + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(self.scheduler.timesteps): + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents, latents], dim=0) + t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype) + else: + latent_model_input = latents + t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype) + + # Model prediction + pred = self.transformer( + hidden_states=latent_model_input, + timestep=t_batch, + text_bth=text_bth, + text_lens=text_lens, + return_dict=False, + )[0] + + # Apply CFG + if self.do_classifier_free_guidance: + pred_uncond, pred_cond = pred.chunk(2, dim=0) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + + # Scheduler step + latents = self.scheduler.step(pred, t, latents).prev_sample + + # Callback + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + + 
progress_bar.update() + + if output_type == "latent": + images = latents + else: + # Decode latents to images + # Unnormalize latents using VAE's BN stats + # TODO: switch to `self.vae.config.batch_norm_eps` once the hub config is updated to match the trained value (1e-5). + bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype) + bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to( + device=device, dtype=latents.dtype + ) + latents = latents * bn_std + bn_mean + + # Unpatchify + latents = self._unpatchify_latents(latents) + + # Decode + images = self.vae.decode(latents, return_dict=False)[0] + + # Post-process + images = self.image_processor.postprocess(images, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (images,) + + return ErnieImagePipelineOutput(images=images, revised_prompts=revised_prompts) \ No newline at end of file From af6668511eb0c4f484d8a4fbcc09dd5c88e606cb Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 17:05:46 +0800 Subject: [PATCH 12/16] Update __init__ --- videox_fun/models/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/videox_fun/models/__init__.py b/videox_fun/models/__init__.py index 66970bf2..550fad73 100755 --- a/videox_fun/models/__init__.py +++ b/videox_fun/models/__init__.py @@ -25,7 +25,14 @@ from transformers import Qwen3VLForConditionalGeneration except: Qwen3VLForConditionalGeneration = None - print("Your transformers version is too old to load Qwen3VLForConditionalGeneration. If you wish to use QwenImage, please upgrade your transformers package to the latest version.") + print("Your transformers version is too old to load Qwen3VLForConditionalGeneration. 
If you wish to use Qwen3VLForConditionalGeneration, please upgrade your transformers package to the latest version.") + +try: + from transformers import Mistral3Model, Ministral3ForCausalLM +except: + Mistral3Model = None + Ministral3ForCausalLM = None + print("Your transformers version is too old to load Mistral3Model and Ministral3ForCausalLM. If you wish to use ErnieImage, please upgrade your transformers package to the latest version.") from .cogvideox_transformer3d import CogVideoXTransformer3DModel from .cogvideox_vae import AutoencoderKLCogVideoX From 5ccf2d27469005ddd733c4f7dff01d53602afa4a Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 17:46:27 +0800 Subject: [PATCH 13/16] Update md --- scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md | 2 +- scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md index 9c8dfd09..d1d30798 100755 --- a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md +++ b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md @@ -345,7 +345,7 @@ export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.j # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.1_self_forcing/train_distill.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=CasualWanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" 
--fsdp_cpu_ram_efficient_loading False scripts/wan2.1_self_forcing/train_distill.py \ --config_path="config/wan2.1/wan_civitai.yaml" \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ diff --git a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md index f453e5fd..2141cfac 100755 --- a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md +++ b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md @@ -346,7 +346,7 @@ export DATASET_META_NAME="datasets/X-Fun-Videos-Demo/metadata_add_width_height.j # export NCCL_P2P_DISABLE=1 NCCL_DEBUG=INFO -accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.1_self_forcing/train_distill.py \ +accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=CasualWanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.1_self_forcing/train_distill.py \ --config_path="config/wan2.1/wan_civitai.yaml" \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$DATASET_NAME \ From 3098d234d2b3d3351f406555fdd997e4bf2d234a Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 13 May 2026 18:13:15 +0800 Subject: [PATCH 14/16] Update comment --- videox_fun/models/ernie_image_transformer.py | 1 + videox_fun/pipeline/pipeline_ernie_image.py | 1 + 2 files changed, 2 insertions(+) diff --git a/videox_fun/models/ernie_image_transformer.py b/videox_fun/models/ernie_image_transformer.py index 32960e4f..b75e7ce5 100644 --- 
a/videox_fun/models/ernie_image_transformer.py +++ b/videox_fun/models/ernie_image_transformer.py @@ -1,3 +1,4 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_ernie_image.py # Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/videox_fun/pipeline/pipeline_ernie_image.py b/videox_fun/pipeline/pipeline_ernie_image.py index dfdbe256..730643ad 100644 --- a/videox_fun/pipeline/pipeline_ernie_image.py +++ b/videox_fun/pipeline/pipeline_ernie_image.py @@ -1,3 +1,4 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py # Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); From a5631085c314a4be5635a2dde03a4cc6893650df Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Thu, 14 May 2026 15:33:36 +0800 Subject: [PATCH 15/16] Support self-forcing with different res --- scripts/wan2.1_self_forcing/train_distill.py | 13 ++++++++++++- videox_fun/models/wan_transformer3d_self_forcing.py | 9 ++++++--- videox_fun/pipeline/pipeline_wan_self_forcing.py | 11 ++++++----- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/scripts/wan2.1_self_forcing/train_distill.py b/scripts/wan2.1_self_forcing/train_distill.py index e219cf4c..f25f8bcf 100644 --- a/scripts/wan2.1_self_forcing/train_distill.py +++ b/scripts/wan2.1_self_forcing/train_distill.py @@ -235,7 +235,7 @@ def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, transformer if args.fix_sample_size is not None: height, width = args.fix_sample_size else: - height, width = args.video_sample_size + height, width = args.video_sample_size, args.video_sample_size sample = pipeline( args.validation_prompts[i], num_frames = 
args.video_sample_n_frames, @@ -1327,6 +1327,17 @@ def collate_fn(examples): # Magvae needs the number of frames to be 4n + 1. batch_video_length = (batch_video_length - 1) // sample_n_frames_bucket_interval * sample_n_frames_bucket_interval + 1 + # KV cache training requires latent frames divisible by num_frame_per_block + if args.use_kv_cache_training: + k = (batch_video_length - 1) // sample_n_frames_bucket_interval + if args.independent_first_frame: + # latent_frames - 1 = k must be divisible by num_frame_per_block + k = (k // args.num_frame_per_block) * args.num_frame_per_block + else: + # latent_frames = k + 1 must be divisible by num_frame_per_block + k = ((k + 1) // args.num_frame_per_block) * args.num_frame_per_block - 1 + batch_video_length = k * sample_n_frames_bucket_interval + 1 + if batch_video_length <= 0: batch_video_length = 1 diff --git a/videox_fun/models/wan_transformer3d_self_forcing.py b/videox_fun/models/wan_transformer3d_self_forcing.py index 6de2e3d2..fa84d3e1 100644 --- a/videox_fun/models/wan_transformer3d_self_forcing.py +++ b/videox_fun/models/wan_transformer3d_self_forcing.py @@ -111,7 +111,6 @@ def __init__(self, self.eps = eps self.local_attn_size = local_attn_size self.sink_size = sink_size - self.max_attention_size = 32760 if local_attn_size == -1 else local_attn_size * 1560 # Layers self.q = nn.Linear(dim, dim) @@ -286,10 +285,14 @@ def qkv_fn(x): kv_cache["v"][:, local_start_index:local_end_index] = v # Compute attention with local window + if self.local_attn_size == -1: + max_attention_size = local_end_index + else: + max_attention_size = self.local_attn_size * frame_seqlen x = attention( roped_query, - kv_cache["k"][:, max(0, local_end_index - self.max_attention_size):local_end_index], - kv_cache["v"][:, max(0, local_end_index - self.max_attention_size):local_end_index] + kv_cache["k"][:, max(0, local_end_index - max_attention_size):local_end_index], + kv_cache["v"][:, max(0, local_end_index - 
max_attention_size):local_end_index] ) kv_cache["global_end_index"].fill_(current_end) kv_cache["local_end_index"].fill_(local_end_index) diff --git a/videox_fun/pipeline/pipeline_wan_self_forcing.py b/videox_fun/pipeline/pipeline_wan_self_forcing.py index f415ec23..2cb56607 100644 --- a/videox_fun/pipeline/pipeline_wan_self_forcing.py +++ b/videox_fun/pipeline/pipeline_wan_self_forcing.py @@ -387,18 +387,18 @@ def attention_kwargs(self): def interrupt(self): return self._interrupt - def _initialize_kv_cache(self, batch_size, dtype, device, frame_seq_length): + def _initialize_kv_cache(self, batch_size, dtype, device, frame_seq_length, num_latent_frames): """ Initialize KV cache for causal self-attention. """ kv_cache_pos = [] kv_cache_neg = [] - # Use the default KV cache size (32760 tokens for global attention) + # Compute KV cache size based on actual resolution and frame count local_attn_size = getattr(self.transformer.config, 'local_attn_size', -1) if local_attn_size != -1: kv_cache_size = local_attn_size * frame_seq_length else: - kv_cache_size = 32760 + kv_cache_size = num_latent_frames * frame_seq_length num_heads = self.transformer.config.num_heads head_dim = self.transformer.config.dim // num_heads @@ -647,7 +647,8 @@ def __call__( # 8. 
Initialize KV cache and cross-attention cache # Reset caches if they exist (for multiple inference calls) - if self.kv_cache_pos is not None: + required_kv_size = num_latent_frames * frame_seq_length + if self.kv_cache_pos is not None and self.kv_cache_pos[0]["k"].shape[1] >= required_kv_size: for block_index in range(len(self.kv_cache_pos)): self.kv_cache_pos[block_index]["global_end_index"] = torch.tensor( [0], dtype=torch.long, device=device) @@ -661,7 +662,7 @@ def __call__( self.crossattn_cache_pos[block_index]["is_init"] = False self.crossattn_cache_neg[block_index]["is_init"] = False else: - self._initialize_kv_cache(batch_size=batch_size, dtype=weight_dtype, device=device, frame_seq_length=frame_seq_length) + self._initialize_kv_cache(batch_size=batch_size, dtype=weight_dtype, device=device, frame_seq_length=frame_seq_length, num_latent_frames=num_latent_frames) self._initialize_crossattn_cache(batch_size=batch_size, dtype=weight_dtype, device=device) # Build all_num_frames list From 41259a9b02a1abb6188e1ccefe83731766766eef Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Tue, 19 May 2026 17:14:38 +0800 Subject: [PATCH 16/16] Update self-forcing training code --- .../README_TRAIN_DISTILL.md | 31 +++++--- .../README_TRAIN_DISTILL_zh-CN.md | 31 +++++--- scripts/wan2.1_self_forcing/train_distill.py | 72 +++++++++++-------- scripts/wan2.1_self_forcing/train_distill.sh | 6 +- 4 files changed, 88 insertions(+), 52 deletions(-) diff --git a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md index d1d30798..c8f6b291 100755 --- a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md +++ b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL.md @@ -203,7 +203,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ 
--lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -218,9 +218,11 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." \ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -277,6 +279,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con | `--gen_update_interval` | Generator update interval | 5 | | `--negative_prompt` | Negative prompt for distillation | Chinese negative prompt | | `--train_sampling_steps` | Training sampling steps | 1000 | +| `--ode_transformer_path` | Path to ODE-trained weights to load into generator transformer3d | `models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt` | **Self-Forcing-Specific Parameters**: @@ -363,7 +366,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -378,9 +381,11 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." 
\ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -423,7 +428,7 @@ accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -438,9 +443,11 @@ accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." \ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -475,7 +482,7 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -490,9 +497,11 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." 
\ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -537,7 +546,7 @@ accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -552,9 +561,11 @@ accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." \ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` diff --git a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md index 2141cfac..053be37c 100755 --- a/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md +++ b/scripts/wan2.1_self_forcing/README_TRAIN_DISTILL_zh-CN.md @@ -204,7 +204,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -219,9 +219,11 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." 
\ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -278,6 +280,7 @@ accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_con | `--gen_update_interval` | 生成器更新间隔 | 5 | | `--negative_prompt` | 用于蒸馏的负向提示词 | 中文负向提示词 | | `--train_sampling_steps` | 训练采样步数 | 1000 | +| `--ode_transformer_path` | ODE 训练权重路径,加载到 generator transformer3d 中 | `models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt` | **Self-Forcing 特有参数**: @@ -364,7 +367,7 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -379,9 +382,11 @@ accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TR --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." \ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -424,7 +429,7 @@ accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -439,9 +444,11 @@ accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." 
\ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -476,7 +483,7 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -491,9 +498,11 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." \ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` @@ -538,7 +547,7 @@ accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -553,9 +562,11 @@ accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." 
\ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram ``` diff --git a/scripts/wan2.1_self_forcing/train_distill.py b/scripts/wan2.1_self_forcing/train_distill.py index f25f8bcf..d56eca52 100644 --- a/scripts/wan2.1_self_forcing/train_distill.py +++ b/scripts/wan2.1_self_forcing/train_distill.py @@ -668,6 +668,12 @@ def parse_args(): default=None, help=("If you want to load the weight from other transformers, input its path."), ) + parser.add_argument( + "--ode_transformer_path", + type=str, + default=None, + help=("If you want to load the ode-trained weight into generator transformer3d, input its path."), + ) parser.add_argument( "--vae_path", type=str, @@ -725,7 +731,7 @@ def parse_args(): parser.add_argument( "--real_guidance_scale", type=float, - default=6.0, + default=4.5, help="The cfg scale for real score.", ) parser.add_argument( @@ -981,6 +987,9 @@ def deepspeed_zero_init_disabled_context_manager(): else: state_dict = torch.load(args.transformer_path, map_location="cpu") state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + state_dict = state_dict["generator_ema"] if "generator_ema" in state_dict else state_dict + if any(k.startswith("model.") for k in state_dict.keys()): + state_dict = {k.replace("model.", "", 1) if k.startswith("model.") else k: v for k, v in state_dict.items()} m, u = generator_transformer3d.load_state_dict(state_dict, strict=False) m, u = real_score_transformer3d.load_state_dict(state_dict, strict=False) @@ -988,6 +997,23 @@ def deepspeed_zero_init_disabled_context_manager(): print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") assert len(u) == 0 + if args.ode_transformer_path is not None: + print(f"From ode checkpoint: {args.ode_transformer_path}") + if args.ode_transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.ode_transformer_path) + else: + 
state_dict = torch.load(args.ode_transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + state_dict = state_dict["generator_ema"] if "generator_ema" in state_dict else state_dict + state_dict = state_dict["generator"] if "generator" in state_dict else state_dict + if any(k.startswith("model.") for k in state_dict.keys()): + state_dict = {k.replace("model.", "", 1) if k.startswith("model.") else k: v for k, v in state_dict.items()} + + m, u = generator_transformer3d.load_state_dict(state_dict, strict=False) + print(f"ode_transformer_path loaded into generator_transformer3d. missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + if args.vae_path is not None: print(f"From checkpoint: {args.vae_path}") if args.vae_path.endswith("safetensors"): @@ -1242,7 +1268,7 @@ def get_length_to_frame_num(token_length): return length_to_frame_num - if args.enable_bucket: + if (args.enable_bucket and args.train_mode != "normal") or args.use_teacher_forcing: aspect_ratio_sample_size = {key : [x / 512 * args.video_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} batch_sampler_generator = torch.Generator().manual_seed(args.seed) batch_sampler = AspectRatioBatchImageVideoSampler( @@ -1823,11 +1849,6 @@ def _batch_encode_vae(pixel_values): if vae_stream_2 is not None: torch.cuda.current_stream().wait_stream(vae_stream_2) - # Encode clean latents for teacher forcing - clean_latents = None - if args.use_teacher_forcing: - clean_latents = _batch_encode_vae(pixel_values) - mask = rearrange(mask, "b f c h w -> b c f h w") mask = torch.concat( [ @@ -2015,8 +2036,6 @@ def convert_flow_pred_to_x0( bsz, channel, num_frames, height, width = target_shape if step % args.gen_update_interval == 0: if args.use_kv_cache_training: - # === KV cache block-by-block training (original Self-Forcing) === - # Calculate frame_seq_length patch_h, patch_w = 
accelerator.unwrap_model(generator_transformer3d).config.patch_size[1:] frame_seq_length = (target_shape[3] * target_shape[4]) // (patch_h * patch_w) @@ -2075,16 +2094,16 @@ def convert_flow_pred_to_x0( torch.rand(1, generator=torch_rng, device=accelerator.device).item() < args.teacher_forcing_prob ) + # Same exit step across all blocks (matches original Self-Forcing default) + num_denoising_steps = len(denoising_step_list) + final_step_index = generate_and_sync_list(num_denoising_steps, device=accelerator.device)[0] + for block_idx, current_num_frames in enumerate(all_num_frames): # Extract noise for current block start_idx = current_start_frame - num_input_frames end_idx = start_idx + current_num_frames noisy_input = generator_noise[:, :, start_idx:end_idx] - # Denoise loop for current block - num_denoising_steps = len(denoising_step_list) - final_step_index = generate_and_sync_list(num_denoising_steps, device=noisy_input.device)[0] - for index, current_timestep in enumerate(denoising_step_list): is_final_step = (index == final_step_index) timestep = torch.full( @@ -2144,19 +2163,15 @@ def convert_flow_pred_to_x0( # Record output output_pred[:, :, current_start_frame:current_start_frame + current_num_frames] = generator_pred_block - # Update KV cache with clean context (teacher forcing) or noisy context + # Update KV cache with clean context (consistent with inference: feed denoised_pred directly) if block_idx < len(all_num_frames) - 1: context_timestep = torch.ones([bsz, current_num_frames], device=accelerator.device, dtype=torch.int64) * args.context_noise - # Use clean latents for teacher forcing, otherwise add noise + # Use clean latents for teacher forcing, otherwise use denoised prediction directly if use_teacher_forcing_step and clean_latents is not None: context_input = clean_latents[:, :, start_idx:end_idx] else: - context_input = add_noise( - generator_pred_block, - torch.randn(generator_pred_block.shape, dtype=generator_pred_block.dtype, 
device=generator_pred_block.device, generator=torch_rng), - context_timestep[:, 0] - ) + context_input = generator_pred_block context_input_list = [context_input[i] for i in range(bsz)] @@ -2446,14 +2461,15 @@ def convert_flow_pred_to_x0( torch.rand(1, generator=torch_rng, device=accelerator.device).item() < args.teacher_forcing_prob ) + # Same exit step across all blocks (matches original Self-Forcing default) + num_denoising_steps = len(denoising_step_list) + final_step_index = generate_and_sync_list(num_denoising_steps, device=accelerator.device)[0] + for block_idx, current_num_frames in enumerate(all_num_frames): start_idx = current_start_frame - num_input_frames end_idx = start_idx + current_num_frames noisy_input = fake_score_critic_noise[:, :, start_idx:end_idx] - num_denoising_steps = len(denoising_step_list) - final_step_index = generate_and_sync_list(num_denoising_steps, device=noisy_input.device)[0] - for index, current_timestep in enumerate(denoising_step_list): is_final_step = (index == final_step_index) timestep = torch.full( @@ -2509,19 +2525,15 @@ def convert_flow_pred_to_x0( output_pred[:, :, current_start_frame:current_start_frame + current_num_frames] = fake_score_denoised_pred_block - # Update KV cache with clean context (teacher forcing) or noisy context + # Update KV cache with clean context (consistent with inference: feed denoised_pred directly) if block_idx < len(all_num_frames) - 1: context_timestep = torch.ones([bsz, current_num_frames], device=accelerator.device, dtype=torch.int64) * args.context_noise - # Use clean latents for teacher forcing, otherwise add noise + # Use clean latents for teacher forcing, otherwise use denoised prediction directly if use_teacher_forcing_step and clean_latents is not None: context_input = clean_latents[:, :, start_idx:end_idx] else: - context_input = add_noise( - fake_score_denoised_pred_block, - torch.randn(fake_score_denoised_pred_block.shape, dtype=fake_score_denoised_pred_block.dtype, 
device=fake_score_denoised_pred_block.device, generator=torch_rng), - context_timestep[:, 0] - ) + context_input = fake_score_denoised_pred_block context_input_list = [context_input[i] for i in range(bsz)] diff --git a/scripts/wan2.1_self_forcing/train_distill.sh b/scripts/wan2.1_self_forcing/train_distill.sh index f324652d..27e2d0e8 100644 --- a/scripts/wan2.1_self_forcing/train_distill.sh +++ b/scripts/wan2.1_self_forcing/train_distill.sh @@ -24,7 +24,7 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --num_train_epochs=100 \ --checkpointing_steps=50 \ --learning_rate=2e-06 \ - --learning_rate_critic=2e-07 \ + --learning_rate_critic=4e-07 \ --lr_scheduler="constant_with_warmup" \ --lr_warmup_steps=100 \ --seed=42 \ @@ -39,7 +39,9 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.1_self_forcing/train_dis --training_with_video_token_length \ --enable_bucket \ --uniform_sampling \ + --use_kv_cache_training \ + --num_frame_per_block=3 \ --train_mode="normal" \ --trainable_modules "." \ - --use_teacher_forcing \ + --ode_transformer_path="models/Diffusion_Transformer/Self-Forcing/checkpoints/ode_init.pt" \ --low_vram