diff --git a/.ipynb_checkpoints/command_train-checkpoint.sh b/.ipynb_checkpoints/command_train-checkpoint.sh
new file mode 100644
index 0000000..30a17c3
--- /dev/null
+++ b/.ipynb_checkpoints/command_train-checkpoint.sh
@@ -0,0 +1,23 @@
+CUDA_VISIBLE_DEVICES=0,1 accelerate launch /root/StableAnimator_Music/train_single.py \
+ --pretrained_model_name_or_path="/root/StableAnimator_Music/checkpoints/stable-video-diffusion-img2vid-xt" \
+ --output_dir="./checkpoints/Animation" \
+ --data_root_path="/root/aist_hdf5/rec" \
+ --data_path="/root/aist_hdf5/full_list.txt" \
+ --dataset_width=512 \
+ --dataset_height=512 \
+ --validation_image_folder="./validation/ground_truth" \
+ --validation_control_folder="./validation/poses" \
+ --validation_image="./validation/reference.png" \
+ --num_workers=2 \
+ --lr_warmup_steps=10 \
+ --sample_n_frames=16 \
+ --learning_rate=1e-5 \
+ --per_gpu_batch_size=1 \
+ --num_train_epochs=10000 \
+ --mixed_precision="fp16" \
+ --gradient_accumulation_steps=1 \
+ --checkpointing_steps=2000 \
+ --validation_steps=500 \
+ --checkpoints_total_limit=5000 \
+ --resume_from_checkpoint="latest" \
+ --max_train_steps=30000
\ No newline at end of file
diff --git a/DWPose/skeleton_extraction.py b/DWPose/skeleton_extraction.py
index 2833d15..701931f 100644
--- a/DWPose/skeleton_extraction.py
+++ b/DWPose/skeleton_extraction.py
@@ -149,7 +149,8 @@ def get_video_pose(video_path, ref_image_path, poses_folder_path=None):
detected_poses = []
files = os.listdir(video_path)
png_files = [f for f in files if f.endswith('.png')]
- png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
+ png_files.sort(key=lambda x: int(x.split('.')[0]))
+# png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
for sub_name in png_files:
sub_driven_image_path = os.path.join(video_path, sub_name)
driven_image = cv2.imread(sub_driven_image_path)
diff --git a/README.md b/README.md
index f517651..d801317 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,4 @@
-# StableAnimator [CVPR2025]
-
-
+# Porked from StableAnimator [CVPR2025]
StableAnimator: High-Quality Identity-Preserving Human Image Animation
@@ -28,37 +26,6 @@ StableAnimator: High-Quality Identity-Preserving Human Image Animation
Comparison results between StableAnimator and state-of-the-art (SOTA) human image animation models highlight the superior performance of StableAnimator in delivering high-fidelity, identity-preserving human image animation.
-
-## Overview
-
-
-
-
- The overview of the framework of StableAnimator.
-
-
-Current diffusion models for human image animation struggle to ensure identity (ID) consistency. This paper presents StableAnimator, the first end-to-end ID-preserving video diffusion framework, which synthesizes high-quality videos without any post-processing, conditioned on a reference image and a sequence of poses. Building upon a video diffusion model, StableAnimator contains carefully designed modules for both training and inference striving for identity consistency. In particular, StableAnimator begins by computing image and face embeddings with off-the-shelf extractors, respectively and face embeddings are further refined by interacting with image embeddings using a global content-aware Face Encoder. Then, StableAnimator introduces a novel distribution-aware ID Adapter that prevents interference caused by temporal layers while preserving ID via alignment. During inference, we propose a novel Hamilton-Jacobi-Bellman (HJB) equation-based optimization to further enhance the face quality. We demonstrate that solving the HJB equation can be integrated into the diffusion denoising process, and the resulting solution constrains the denoising path and thus benefits ID preservation. Experiments on multiple benchmarks show the effectiveness of StableAnimator both qualitatively and quantitatively.
-
-## News
-* `[2025-3-10]`:๐ฅThe codes of HJB-based face optimization are released!
-* `[2025-2-27]`:๐ฅ StableAnimator is accepted by CVPR2025๐๐๐. The code of HJB-based face optimization will be released in March. Stay tuned!
-* `[2024-12-13]`:๐ฅ The training code and training tutorial are released! You can train/finetune your own StableAnimator on your own collected datasets! Other codes will be released very soon. Stay tuned!
-* `[2024-12-10]`:๐ฅ The gradio interface is released! Many thanks to [@gluttony-10](https://space.bilibili.com/893892) for his contribution! Other codes will be released very soon. Stay tuned!
-* `[2024-12-6]`:๐ฅ All data preprocessing codes (human skeleton extraction and human face mask extraction) are released! The training code and detailed training tutorial will be released before 2024.12.13. Stay tuned!
-* `[2024-12-4]`:๐ฅ We are thrilled to release an interesting dance demo (๐ฅ๐ฅAPT Dance๐ฅ๐ฅ)! The generated video can be seen on [YouTube](https://www.youtube.com/watch?v=KNPoAsWr_sk) and [Bilibili](https://www.bilibili.com/video/BV1KczXYhER7).
-* `[2024-11-28]`:๐ฅ The data pre-processing codes (human skeleton extraction) are available! Other codes will be released very soon. Stay tuned!
-* `[2024-11-26]`:๐ฅ The project page, code, technical report and [a basic model checkpoint](https://huggingface.co/FrancisRing/StableAnimator/tree/main) are released. Further training codes, data pre-processing codes, the evaluation dataset and StableAnimator-pro will be released very soon. Stay tuned!
-
-## To-Do List
-- [x] StableAnimator-basic
-- [x] Inference Code
-- [x] Evaluation Samples
-- [x] Data Pre-Processing Code (Skeleton Extraction)
-- [x] Data Pre-Processing Code (Human Face Mask Extraction)
-- [x] Training Code
-- [x] Inference Code with HJB-based Face Optimization
-- [ ] StableAnimator-pro
-
## Quickstart
For the basic version of the model checkpoint, it supports generating videos at a 576x1024 or 512x512 resolution. If you encounter insufficient memory issues, you can appropriately reduce the number of animated frames.
@@ -133,264 +100,3 @@ mv ./models/antelopev2/antelopev2 ./models/tmp
rm -rf ./models/antelopev2
mv ./models/tmp ./models/antelopev2
```
-
-### Evaluation Samples
-The evaluation samples presented in the paper can be downloaded from [OneDrive](https://1drv.ms/f/c/becb962aad1a1f95/EubdzCAI7BFLhJff2LrHkt8BC9mOiwJ5V67t-ypxRnCK4Q?e=ElEmcn) or `inference.zip` in checkpoints. Please download evaluation samples manually as follows:
-```
-cd StableAnimator
-mkdir inference
-```
-All the evaluation samples should be organized as follows:
-```
-inference/
-โโโ case-1
-โย ย โโโ poses
-โย ย โโโ faces
-โย ย โโโ reference.png
-โโโ case-2
-โย ย โโโ poses
-โย ย โโโ faces
-โย ย โโโ reference.png
-โโโ case-3
-โย ย โโโ poses
-โย ย โโโ faces
-โย ย โโโ reference.png
-```
-
-### Human Skeleton Extraction
-We leverage the pre-trained DWPose to extract the human skeletons. In the initialization of DWPose, the pretrained weights should be configured in `/DWPose/dwpose_utils/wholebody.py`:
-```
-onnx_det = 'path/checkpoints/DWPose/yolox_l.onnx'
-onnx_pose = 'path/checkpoints/DWPose/dw-ll_ucoco_384.onnx'
-```
-Given the target image folder containing multiple .png files, you can use the following command to obtain the corresponding human skeleton images:
-```
-python DWPose/skeleton_extraction.py --target_image_folder_path="path/test/target_images" --ref_image_path="path/test/reference.png" --poses_folder_path="path/test/poses"
-```
-It is worth noting that the .png files in the target image folder are named in the format `frame_i.png`, such as `frame_0.png`, `frame_1.png`, and so on.
-`--ref_image_path` refers to the path of the given reference image. The obtained human skeleton images are saved in `path/test/poses`. It is particularly significant that the target skeleton images should be aligned with the reference image regarding the body shape.
-
-If you only have the target MP4 file (target.mp4), we recommend you to use `ffmpeg` to convert the MP4 file to multiple frames (.png files) without any quality loss.
-```
-ffmpeg -i target.mp4 -q:v 1 -start_number 0 path/test/target_images/frame_%d.png
-```
-The obtained frames are saved in `path/test/target_images`.
-
-### Human Face Mask Extraction
-Given the path to an image folder containing multiple RGB `.png` files, you can run the following command to extract the corresponding human face masks:
-```
-python face_mask_extraction.py --image_folder="path/StableAnimator/inference/your_case/target_images"
-```
-`path/StableAnimator/inference/your_case/target_images` contains multiple `.png` files. The obtained masks are saved in `path/StableAnimator/inference/your_case/faces`.
-
-### Base Model inference
-A sample configuration for testing is provided as `command_basic_infer.sh`. You can also easily modify the various configurations according to your needs.
-
-```
-bash command_basic_infer.sh
-```
-StableAnimator supports human image animation at two different resolution settings: 512x512 and 576x1024. You can modify "--width" and "--height" in `command_basic_infer.sh` to set the resolution of the animation. "--output_dir" in `command_basic_infer.sh` refers to the saved path of the generated animation. "--validation_control_folder" and "--validation_image" in `command_basic_infer.sh` refer to the paths of the given pose sequence and the reference image, respectively.
-"--pretrained_model_name_or_path" in `command_basic_infer.sh` is the path of pretrained SVD. "posenet_model_name_or_path", "face_encoder_model_name_or_path", and "unet_model_name_or_path" in `command_basic_infer.sh` refer to paths of pretrained StableAnimator weights.
-If you have enough GPU resources, you can increase the value (4=>8=>16) of "--decode_chunk_size" in `command_basic_infer.sh` to promote the temporal smoothness of the animation.
-
-Tips: if your GPU memory is limited, you can reduce the number of animated frames. This command will generate two files: animated_images and animated_images.gif.
-If you want to obtain the high quality MP4 file, we recommend you to leverage ffmpeg on the animated_images as follows:
-```
-cd animated_images
-ffmpeg -framerate 20 -i frame_%d.png -c:v libx264 -crf 10 -pix_fmt yuv420p /path/animation.mp4
-```
-"-framerate" refers to the fps setting. "-crf" indicates the quality of the generated MP4 file, with smaller values corresponding to higher quality.
-Additionally, you can also run the following command to launch a Gradio interface:
-```
-python app.py
-```
-
-### Model inference with HJB-based Face Optimization
-A sample configuration for testing is provided as `command_op_infer.sh`. You can also easily modify the various configurations according to your needs.
-```
-bash command_op_infer.sh
-```
-`--num_optimization_iter`, `--start_refine_step`, and `--end_refine_step` refer to the epoch number of HJB-based face optimization at each timestep, the start time of the optimization, and the end time of the optimization, respectively.
-These three parameters need to be adaptively modified in certain situations based on the specific input videos and reference image.
-`--face_embedding_extractor_weight_path` can be downloaded from [HuggingFace](https://huggingface.co/FrancisRing/StableAnimator/tree/main/Animation).
-Notably, you should extract the corresponding face masks before conducting our JHB-based optimization. For more details about human face extraction, please refer to the Human Face Mask Extraction Section in the README file.
-
-### Model Training
-๐ฅItโs worth noting that if youโre looking to train a conditioned Stable Video Diffusion (SVD) model, this training tutorial will also be helpful.๐ฅ
-For the training dataset, it has to be organized as follows:
-
-```
-animation_data/
-โโโ rec
-โย ย โย ย โโโ00001
-โย ย โย ย โย ย โโโimages
-โย ย โย ย โย ย โย ย โโโframe_0.png
-โย ย โย ย โย ย โย ย โโโframe_1.png
-โย ย โย ย โย ย โย ย โโโframe_2.png
-โย ย โย ย โย ย โย ย โโโ...
-โย ย โย ย โย ย โโโfaces
-โย ย โย ย โย ย โย ย โโโframe_0.png
-โย ย โย ย โย ย โย ย โโโframe_1.png
-โย ย โย ย โย ย โย ย โโโframe_2.png
-โย ย โย ย โย ย โย ย โโโ...
-โย ย โย ย โย ย โโโposes
-โย ย โย ย โย ย โย ย โโโframe_0.png
-โย ย โย ย โย ย โย ย โโโframe_1.png
-โย ย โย ย โย ย โย ย โโโframe_2.png
-โย ย โย ย โย ย โย ย โโโ...
-โย ย โย ย โโโ00002
-โย ย โย ย โย ย โโโimages
-โย ย โย ย โย ย โโโfaces
-โย ย โย ย โย ย โโโposes
-โย ย โย ย โโโ00003
-โย ย โย ย โย ย โโโimages
-โย ย โย ย โย ย โโโfaces
-โย ย โย ย โย ย โโโposes
-โย ย โย ย โโโ...
-โโโ vec
-โย ย โย ย โโโ00001
-โย ย โย ย โย ย โโโimages
-โย ย โย ย โย ย โโโfaces
-โย ย โย ย โย ย โโโposes
-โย ย โย ย โโโ00002
-โย ย โย ย โย ย โโโimages
-โย ย โย ย โย ย โโโfaces
-โย ย โย ย โย ย โโโposes
-โย ย โย ย โโโ00003
-โย ย โย ย โย ย โโโimages
-โย ย โย ย โย ย โโโfaces
-โย ย โย ย โย ย โโโposes
-โย ย โย ย โโโ...
-โโโ video_rec_path.txt
-โโโ video_vec_path.txt
-```
-StableAnimator is trained on mixed-resolution videos, with 512x512 videos stored in `animation_data/rec` and 576x1024 videos stored in `animation_data/vec`. Each folder in `animation_data/rec` or `animation_data/vec` contains three subfolders which contains multiple `.png` image files.
-All `.png` image files are named in the format `frame_i.png`, such as `frame_0.png`, `frame_1.png`, and so on.
-`00001`, `00002`, `00003` indicate individual video information.
-In terms of three subfolders, `images`, `faces`, and `poses` store RGB frames, corresponding human face masks, and corresponding human skeleton poses, respectively.
-`video_rec_path.txt` and `video_vec_path.txt` record folder paths of `animation_data/rec` and `animation_data/vec`, respectively.
-For example, the content of `video_rec_path.txt` is shown as follows:
-```
-path/StableAnimator/animation_data/rec/00001
-path/StableAnimator/animation_data/rec/00002
-path/StableAnimator/animation_data/rec/00003
-path/StableAnimator/animation_data/rec/00004
-path/StableAnimator/animation_data/rec/00005
-path/StableAnimator/animation_data/rec/00006
-...
-```
-If you only have raw videos, you can leverage `ffmpeg` to extract frames from raw videos and store them in the subfolder `images`.
-```
-ffmpeg -i raw_video_1.mp4 -q:v 1 -start_number 0 path/StableAnimator/animation_data/rec/00001/images/frame_%d.png
-```
-The obtained frames are saved in `path/StableAnimator/animation_data/rec/00001/images`.
-
-For extracting the human skeleton poses, you can run the following command:
-```
-python DWPose/training_skeleton_extraction.py --root_path="path/StableAnimator/animation_data" --name="rec" --start=1 --end=500
-```
-`--root_path` and `--name` refer to the root path of training datasets and the name of the dataset.
-`--start` and `--end` specify the starting and ending indices of the selected training dataset. For example, `--name="rec" --start=1 --end=500` indicates that the skeleton extraction will start at `path/StableAnimator/animation_data/rec/00001` and end at `path/StableAnimator/animation_data/rec/00500`.
-
-For extraction details of corresponding face masks, please refer to the Human Face Mask Extraction section.
-When your dataset is organized exactly as outlined above, you can easily train your StableAnimator by running the following command:
-```
-bash command_train.sh
-```
-For the parameter details of `command_train.sh`, `CUDA_VISIBLE_DEVICES` refers to gpu devices. In my setting, I use 4 NVIDIA A100 80G to train StableAnimator (`CUDA_VISIBLE_DEVICES=3,2,1,0`).
-`--pretrained_model_name_or_path` and `--output_dir` refer to the pretrained SVD path and the checkpoint saved path of the trained StableAnimator.
-`--data_root_path`, `--rec_data_path`, and `--vec_data_path` are the root path of datasets, the path of `video_rec_path.txt`, and the path of `video_vec_path.txt`, respectively.
-`validation_image_folder`, `validation_control_folder`, and `validation_image` are paths of validation ground truths, validation driven skeleton poses, and the validation reference image.
-`--sample_n_frames` is the number of frames that StableAnimator processes in a single batch.
-`--num_train_epochs` is the training epoch number. It is worth noting that the default number of training epochs is set to infinite. You can manually terminate the training process once you observe that your StableAnimator has reached its peak performance.
-The overall file structure of StableAnimator at training is shown as follows:
-```
-StableAnimator/
-โโโ DWPose
-โโโ animation
-โโโ animation_data
-โย ย โโโ rec
-โย ย โโโ vec
-โย ย โโโ video_rec_path.txt
-โย ย โโโ video_vec_path.txt
-โโโ validation
-โย ย โโโ ground_truth
-โย ย โย โโโ frame_0.png
-โย ย โย โโโ frame_1.png
-โย ย โย โโโ frame_2.png
-โย ย โย โโโ ...
-โย ย โโโ poses
-โย ย โย โโโ frame_0.png
-โย ย โย โโโ frame_1.png
-โย ย โย โโโ frame_2.png
-โย ย โย โโโ ...
-โย ย โโโ reference.png
-โโโ checkpoints
-โย ย โโโ DWPose
-โย ย โย โโโ dw-ll_ucoco_384.onnx
-โย ย โย ย โโโ yolox_l.onnx
-โย ย โโโ Animation
-โย ย โย ย โโโ pose_net.pth
-โย ย โย ย โโโ face_encoder.pth
-โย ย โย ย โโโ unet.pth
-โย ย โโโ SVD
-โย ย โย ย โโโ feature_extractor
-โย ย โย ย โโโ image_encoder
-โย ย โย ย โโโ scheduler
-โย ย โย ย โโโ unet
-โย ย โย ย โโโ vae
-โย ย โย ย โโโ model_index.json
-โย ย โย ย โโโ svd_xt.safetensors
-โย ย โย ย โโโ svd_xt_image_decoder.safetensors
-โย ย โโโ inference.zip
-โโโ models
-โ โ โโโ antelopev2
-โย ย โย ย โโโ 1k3d68.onnx
-โย ย โย ย โโโ 2d106det.onnx
-โย ย โย ย โโโ genderage.onnx
-โย ย โย ย โโโ glintr100.onnx
-โย ย โย ย โโโ scrfd_10g_bnkps.onnx
-โโโ app.py
-โโโ command_basic_infer.sh
-โโโ inference_basic.py
-โโโ train.py
-โโโ command_train.sh
-โโโ requirement.txt
-```
-It is worth noting that training StableAnimator requires approximately 70GB of VRAM due to the mixed-resolution (512x512 and 576x1024) training pipeline.
-However, if you train StableAnimator exclusively on 512x512 videos, the VRAM requirement is reduced to approximately 40GB.
-Additionally, The backgrounds of the selected training videos should remain static, as this helps the diffusion model calculate accurate reconstruction loss.
-
-If you want to train StableAnimator on a single resolution, you can use the following command:
-```
-bash command_train_single.sh
-```
-You can customize the resolution by modifying `--dataset_width` and `--dataset_height`, both of which default to 512.
-
-Regarding finetuning StableAnimator, you can run the following command:
-```
-bash command_finetune.sh
-```
-`posenet_model_finetune_path`, `face_encoder_finetune_path`, and `unet_model_finetune_path` in `command_finetune.sh` refer to paths of pretrained StableAnimator weights.
-
-### VRAM requirement and Runtime
-
-For the 15s demo video (512x512, fps=30), the 16-frame basic model requires 8GB VRAM and finishes in 5 minutes on a 4090 GPU.
-
-The minimum VRAM requirement for the 16-frame U-Net of the pro model is 10GB (576x1024, fps=30); however, the VAE decoder demands 16GB. You have the option to run the VAE decoder on CPU.
-
-## Contact
-If you have any suggestions or find our work helpful, feel free to contact me
-
-Email: francisshuyuan@gmail.com
-
-If you find our work useful, please consider giving a star to this github repository and citing it:
-```bib
-@inproceedings{tu2025stableanimator,
- title={Stableanimator: High-quality identity-preserving human image animation},
- author={Tu, Shuyuan and Xing, Zhen and Han, Xintong and Cheng, Zhi-Qi and Dai, Qi and Luo, Chong and Wu, Zuxuan},
- booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
- pages={21096--21106},
- year={2025}
-}
-```
diff --git a/animation/dataset/animation_new_dataset.py b/animation/dataset/animation_new_dataset.py
new file mode 100644
index 0000000..1dfeefe
--- /dev/null
+++ b/animation/dataset/animation_new_dataset.py
@@ -0,0 +1,348 @@
+import os
+import os.path as osp
+import random
+import warnings
+
+import numpy as np
+import torch
+import cv2
+import h5py
+from PIL import Image
+from torch.utils.data.dataset import Dataset
+from einops import rearrange
+from animation.modules.face_model import FaceModel
+
+
+
+class LargeScaleMusicVideos(Dataset):
+ def __init__(self, root_path, txt_path, width, height, n_sample_frames, sample_frame_rate,
+ sample_margin=30, app=None, handler_ante=None, face_helper=None):
+ self.root_path = root_path
+ self.txt_path = txt_path
+ self.width = width
+ self.height = height
+ self.n_sample_frames = n_sample_frames
+ self.sample_frame_rate = sample_frame_rate
+ self.sample_margin = sample_margin
+
+ self.video_files = self._read_txt_file_images()
+
+ self.app = app
+ self.handler_ante = handler_ante
+ self.face_helper = face_helper
+
+ def _read_txt_file_images(self):
+ with open(self.txt_path, 'r') as file:
+ video_files = [line.strip() for line in file if line.strip()]
+ return video_files
+
+ def __len__(self):
+ return len(self.video_files)
+
+ def frame_count(self, frames_path):
+ files = os.listdir(frames_path)
+ image_files = [f for f in files if f.endswith(('.png', '.jpg'))]
+ return len(image_files)
+
+ def find_frames_list(self, frames_path):
+ files = os.listdir(frames_path)
+ image_files = [f for f in files if f.endswith(('.png', '.jpg'))]
+
+ if image_files and image_files[0].startswith('frame_'):
+ image_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
+ else:
+ image_files.sort(key=lambda x: int(x.split('.')[0]))
+ return image_files
+
+
+ def get_face_masks(self, pil_img):
+ rgb_image = np.array(pil_img)
+ bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
+ image_info = self.app.get(bgr_image)
+ mask = np.zeros((self.height, self.width), dtype=np.uint8)
+
+ if len(image_info) > 0:
+ for info in image_info:
+ x_1 = info['bbox'][0]
+ y_1 = info['bbox'][1]
+ x_2 = info['bbox'][2]
+ y_2 = info['bbox'][3]
+ cv2.rectangle(mask, (int(x_1), int(y_1)), (int(x_2), int(y_2)), (255), thickness=cv2.FILLED)
+ mask = mask.astype(np.float64) / 255.0
+ else:
+ self.face_helper.clean_all()
+ with torch.no_grad():
+ bboxes = self.face_helper.face_det.detect_faces(bgr_image, 0.97)
+ if len(bboxes) > 0:
+ for bbox in bboxes:
+ cv2.rectangle(mask, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (255),
+ thickness=cv2.FILLED)
+ mask = mask.astype(np.float64) / 255.0
+ else:
+ mask = np.ones((self.height, self.width), dtype=np.uint8)
+ return mask
+
+ def resize_and_center_crop(self, img, target_width, target_height):
+ """์ด๋ฏธ์ง๋ฅผ ๋น์จ ์ ์งํ๋ฉฐ ๋ฆฌ์ฌ์ด์ฆ ํ center crop"""
+ width, height = img.size
+
+ # ํ๊ฒ ์ฌ์ด์ฆ๋ณด๋ค ์์ผ๋ฉด ๊ทธ๋ฅ ๋ฆฌ์ฌ์ด์ฆ
+ if width <= target_width and height <= target_height:
+ return img.resize((target_width, target_height), Image.LANCZOS)
+
+ # ๋น์จ ๊ณ์ฐ (์งง์ ์ชฝ์ ํ๊ฒ์ ๋ง์ถค)
+ scale = max(target_width / width, target_height / height)
+ new_width = int(width * scale)
+ new_height = int(height * scale)
+
+ # ๋ฆฌ์ฌ์ด์ฆ
+ img = img.resize((new_width, new_height), Image.LANCZOS)
+
+ # Center crop
+ left = (new_width - target_width) // 2
+ top = (new_height - target_height) // 2
+ right = left + target_width
+ bottom = top + target_height
+
+ return img.crop((left, top, right, bottom))
+
+ def __getitem__(self, idx):
+ try:
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ warnings.filterwarnings('ignore', category=FutureWarning)
+
+ base_path = '/root/aist_hdf5/rec/'
+ video_base_path = self.video_files[idx]
+ frames_path = osp.join(video_base_path, "images")
+ #poses_path = osp.join(video_base_path, "poses")
+ face_masks_path = osp.join(video_base_path, "faces")
+
+ video_length = self.frame_count(frames_path)
+ frames_list = self.find_frames_list(frames_path)
+
+ # Music feature ๋ก๋
+ mus_path = video_base_path.split('/')[-1] + '.h5'
+ mus_path = osp.join(base_path, mus_path)
+
+ music_fea = torch.zeros((video_length, 4800), dtype=torch.float32)
+ if osp.exists(mus_path):
+ try:
+ with h5py.File(mus_path, "r") as f:
+ m = f["music"][:]
+
+ # Shape ์ ๋ฆฌ
+ if m.ndim == 3 and m.shape[0] == 1:
+ m = m.squeeze(0) # (1, T, 4800) -> (T, 4800)
+
+ if m.ndim == 2:
+ # (T, 4800) ํํ๊ฐ ๋๋๋ก ํ์ธ
+ if m.shape[1] != 4800 and m.shape[0] == 4800:
+ m = m.T # (4800, T) -> (T, 4800)
+ # (720, 4800) ํํ๋ผ๋ฉด ๊ทธ๋๋ก ์ ์ง (์๊ฐ์ถ์ด 720)
+
+ music_fea = torch.from_numpy(m).float()
+ # print(f"H5 loaded: music={music_fea.shape}")
+ except Exception as e:
+ print(f"Failed to load H5 file: {mus_path}, error: {e}")
+
+ # ํด๋ฆฝ ๊ธธ์ด ๊ณ์ฐ
+ clip_length = min(video_length, (self.n_sample_frames - 1) * self.sample_frame_rate + 1)
+ start_idx = random.randint(0, video_length - clip_length)
+ batch_index = np.linspace(
+ start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
+ ).tolist()
+
+ # ์ฐธ์กฐ ํ๋ ์ ์ ํ
+ all_indices = list(range(video_length))
+ available_indices = [i for i in all_indices if i not in batch_index]
+
+ if available_indices:
+ reference_frame_idx = random.choice(available_indices)
+ else:
+ extreme_sample_frame_rate = 2
+ extreme_clip_length = min(video_length, (self.n_sample_frames - 1) * extreme_sample_frame_rate + 1)
+ extreme_start_idx = random.randint(0, video_length - extreme_clip_length)
+ extreme_batch_index = np.linspace(
+ extreme_start_idx, extreme_start_idx + extreme_clip_length - 1,
+ self.n_sample_frames, dtype=int
+ ).tolist()
+ extreme_available_indices = [i for i in all_indices if i not in extreme_batch_index]
+
+ if extreme_available_indices:
+ reference_frame_idx = random.choice(extreme_available_indices)
+ else:
+ raise ValueError(f"No available reference frame in {frames_path}")
+
+ # ์ฐธ์กฐ ํ๋ ์ ๋ก๋
+ reference_frame_path = osp.join(frames_path, frames_list[reference_frame_idx])
+ reference_pil_image = Image.open(reference_frame_path).convert('RGB')
+ reference_pil_image = self.resize_and_center_crop(reference_pil_image, self.width, self.height)
+ reference_pil_image = torch.from_numpy(np.array(reference_pil_image)).float()
+ reference_pil_image = reference_pil_image / 127.5 - 1
+
+ reference_frame_face_pil = Image.open(reference_frame_path).convert('RGB')
+ reference_frame_face_pil = self.resize_and_center_crop(reference_frame_face_pil, self.width, self.height)
+ reference_frame_face = np.array(reference_frame_face_pil)
+ reference_frame_face_bgr = cv2.cvtColor(reference_frame_face, cv2.COLOR_RGB2BGR)
+ reference_frame_face_info = self.app.get(reference_frame_face_bgr)
+ if len(reference_frame_face_info) > 0:
+ reference_frame_face_info = sorted(
+ reference_frame_face_info,
+ key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1])
+ )[-1]
+ reference_frame_id_ante_embedding = reference_frame_face_info['embedding']
+ else:
+ reference_frame_id_ante_embedding = None
+
+ # ํ๊ฒ ํ๋ ์๋ค ๋ก๋
+ # pose_pil_image_list = []
+ tgt_pil_image_list = []
+ tgt_face_masks_list = []
+
+ for index in batch_index:
+ tgt_img_path = osp.join(frames_path, frames_list[index])
+ img_basename = osp.splitext(osp.basename(tgt_img_path))[0]
+ frame_number = int(img_basename.lstrip('0')) if img_basename != '0' else 0
+
+ # pose_name = f"frame_{frame_number}.png"
+ # pose_path = osp.join(poses_path, pose_name)
+
+ face_name = img_basename + '.png'
+ face_mask_path = osp.join(face_masks_path, face_name)
+
+ # ํ๊ฒ ์ด๋ฏธ์ง
+ try:
+ tgt_img_pil = Image.open(tgt_img_path).convert('RGB')
+ tgt_img_pil = self.resize_and_center_crop(tgt_img_pil, self.width, self.height)
+ tgt_img_tensor = torch.from_numpy(np.array(tgt_img_pil)).float()
+ tgt_img_normalized = tgt_img_tensor / 127.5 - 1
+ tgt_pil_image_list.append(tgt_img_normalized)
+ except Exception as e:
+ print(f"Failed loading image: {tgt_img_path}, error: {e}")
+ tgt_pil_image_list.append(torch.zeros(self.height, self.width, 3))
+
+ # Pose
+ # try:
+ # pose = Image.open(pose_path).convert('RGB')
+ # pose = self.resize_and_center_crop(pose, self.width, self.height)
+ # pose = torch.from_numpy(np.array(pose)).float()
+ # pose = pose / 127.5 - 1
+ # except Exception as e:
+ # print(f"Failed loading pose: {pose_path}, error: {e}")
+ # pose = torch.zeros(self.height, self.width, 3)
+ # pose_pil_image_list.append(pose)
+
+ # Face mask
+ try:
+ face = Image.open(face_mask_path).convert('L')
+ face = self.resize_and_center_crop(face, self.width, self.height)
+ face = torch.from_numpy(np.array(face)).float()
+ face = face / 255.0
+ face = face.unsqueeze(0) # โ (H, W) -> (1, H, W)
+ except Exception as e:
+ print(f"Failed loading face: {face_mask_path}, error: {e}")
+ face = torch.zeros(1, self.height, self.width)
+ tgt_face_masks_list.append(face)
+
+ # Music feature ์ ํ
+ music_selected = music_fea[batch_index] # (n_sample_frames, 4800)
+
+ # ํ
์ ๋ณํ ๋ฐ ์ฐจ์ ์ฌ๋ฐฐ์ด
+ tgt_pil_image_list = torch.stack(tgt_pil_image_list, dim=0)
+ #pose_pil_image_list = torch.stack(pose_pil_image_list, dim=0)
+ tgt_face_masks_list = torch.stack(tgt_face_masks_list, dim=0)
+
+ # (F, H, W, C) -> (F, C, H, W)
+ tgt_pil_image_list = rearrange(tgt_pil_image_list, "f h w c -> f c h w")
+ reference_pil_image = rearrange(reference_pil_image, "h w c -> c h w")
+ #pose_pil_image_list = rearrange(pose_pil_image_list, "f h w c -> f c h w")
+ #tgt_face_masks_list = rearrange(tgt_face_masks_list, "f h w c -> f c h w")
+
+ sample = dict(
+ pixel_values=tgt_pil_image_list, # (F, 3, H, W)
+ reference_image=reference_pil_image, # (3, H, W)
+ #pose_pixels=pose_pil_image_list, # (F, 3, H, W)
+ faceid_embeds=reference_frame_id_ante_embedding, # None
+ tgt_face_masks=tgt_face_masks_list, # (F, 3, H, W) # (H, W)
+ music_fea=music_selected, # (F, 4800)
+ )
+ if (
+ sample["pixel_values"] is None
+ or sample["reference_image"] is None
+ or sample["faceid_embeds"] is None
+ or sample["tgt_face_masks"] is None
+ or sample["music_fea"] is None
+ ):
+ raise ValueError(f"Invalid sample at idx={idx}, path={self.video_files[idx]}")
+
+ return sample
+ except Exception as e:
+ print(f"[WARN] Skipped sample idx={idx}, path={self.video_files[idx]}, error={e}")
+ return None
+
+
+if __name__ == "__main__":
+ print("=" * 70)
+ print("Testing LargeScaleMusicVideos Dataset")
+ print("=" * 70)
+ face_model = FaceModel()
+ dataset = LargeScaleMusicVideos(
+ root_path="/root/aist_hdf5/rec",
+ txt_path="/root/aist_hdf5/full_list.txt",
+ width=512,
+ height=512,
+ n_sample_frames=16,
+ sample_frame_rate=2,
+ app=face_model.app,
+ handler_ante=face_model.handler_ante,
+ face_helper=face_model.face_helper
+ )
+
+ print(f"\nDataset Info:")
+ print(f" - Total videos: {len(dataset)}")
+ print(f" - Sample frames: {dataset.n_sample_frames}")
+ print(f" - Frame rate: {dataset.sample_frame_rate}")
+ print(f" - Output size: {dataset.width}x{dataset.height}")
+
+ print("\n" + "=" * 70)
+ print("Loading first sample...")
+ print("=" * 70)
+
+ errors = []
+ for idx in range(len(dataset)):
+ print(f"\n[{idx+1}/{len(dataset)}] Loading: {dataset.video_files[idx]}")
+ try:
+ sample = dataset[idx]
+ # ๊ฐ key๋ณ shape/type/range ์ถ๋ ฅ
+ for key, value in sample.items():
+ if isinstance(value, torch.Tensor):
+ print(f" [{key}] - shape: {tuple(value.shape)}, dtype: {value.dtype}, "
+ f"range: [{value.min():.3f}, {value.max():.3f}], mean: {value.mean():.3f}")
+ elif isinstance(value, np.ndarray):
+ print(f" [{key}] - shape: {value.shape}, dtype: {value.dtype}")
+ else:
+ print(f" [{key}] - type: {type(value)}")
+
+ # validation ์ฒดํฌ
+ assert sample['pixel_values'].shape[0] == dataset.n_sample_frames
+ assert sample['music_fea'].shape[0] == dataset.n_sample_frames
+ assert sample['music_fea'].shape[1] == 4800
+
+ print(" -> Validation passed")
+
+ except Exception as e:
+ print(f" !! Error in {dataset.video_files[idx]}: {e}")
+ import traceback
+ traceback.print_exc()
+ errors.append((dataset.video_files[idx], str(e)))
+
+ print("\n" + "=" * 70)
+ print("FINISHED VALIDATION")
+ print("=" * 70)
+
+ if errors:
+ print(f"\n{len(errors)} errors found:")
+ for f, msg in errors:
+ print(f" - {f}: {msg}")
+ else:
+ print("\nAll samples loaded and validated successfully!")
\ No newline at end of file
diff --git a/animation/modules/music_encoder.py b/animation/modules/music_encoder.py
new file mode 100644
index 0000000..14762e2
--- /dev/null
+++ b/animation/modules/music_encoder.py
@@ -0,0 +1,60 @@
+from pathlib import Path
+
+import einops
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from diffusers.models.modeling_utils import ModelMixin
+
+class MusicEncoder(ModelMixin):
+ def __init__(self, indim=4800, hw=64, noise_latent_channels=320):
+ super().__init__()
+ self.hw = hw
+ self.noise_latent_channels = noise_latent_channels
+ self.latent_dim = hw * hw
+
+ # projection to 64x64
+ self.net = nn.Linear(indim, self.latent_dim)
+
+ # 1โnoise_latent_channels ์ฑ๋ ํ์ฅ
+ self.expand = nn.Conv2d(1, noise_latent_channels, kernel_size=3, padding=1)
+
+ def forward(self, x): # (B, T, 4800)
+ B, T, F = x.shape
+ z = self.net(x.view(B*T, F)) # (B*T, 4096)
+ z = z.view(B*T, 1, self.hw, self.hw) # (B*T, 1, 64, 64)
+ z = self.expand(z) # (B*T, noise_latent_channels, 64, 64)
+ return z
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_path, latent_dim=1024):
+ """Load pretrained music encoder weights"""
+ if not Path(pretrained_model_path).exists():
+ raise FileNotFoundError(f"No model file at {pretrained_model_path}")
+
+ print(f"Loading MusicEncoder from {pretrained_model_path}")
+
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
+
+ # Pretrained ๋ชจ๋ธ์ latent_dim ํ์ธ
+ pretrained_latent_dim = state_dict['net.4.weight'].shape[0]
+
+ model = cls(in_dim=4800, latent_dim=pretrained_latent_dim)
+ model.load_state_dict(state_dict, strict=True)
+
+ # latent_dim์ด ๋ค๋ฅด๋ฉด projection layer ์ถ๊ฐ
+ if pretrained_latent_dim != latent_dim:
+ print(f"Adding projection layer: {pretrained_latent_dim} -> {latent_dim}")
+ model.projection = nn.Linear(pretrained_latent_dim, latent_dim)
+ model.latent_dim = latent_dim
+
+ return model
+
+ def forward_with_projection(self, x):
+ """Projection layer๊ฐ ์์ ๋ ์ฌ์ฉ"""
+ z = self.forward(x) # (B, T, pretrained_latent_dim)
+ if hasattr(self, 'projection'):
+ z = self.projection(z.view(-1, z.shape[-1])) # (B*T, latent_dim)
+ z = z.view(x.shape[0], x.shape[1], -1) # (B, T, latent_dim)
+ return z
diff --git a/animation/modules/new.py b/animation/modules/new.py
new file mode 100644
index 0000000..2ad622d
--- /dev/null
+++ b/animation/modules/new.py
@@ -0,0 +1 @@
+import os import math import argparse import numpy as np from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from einops import rearrange, repeat import imageio.v3 as iio import os import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler import os, glob, numpy as np, torch from torch.utils.data import Dataset from PIL import Image import torchvision.transforms.functional as TF def unwrap_ddp(model): return model.module if hasattr(model, "module") else model def save_model_only(path, model): os.makedirs(os.path.dirname(path), exist_ok=True) state = unwrap_ddp(model).state_dict() torch.save(state, path) print(f"[CKPT] saved model weights -> {path}") def setup_ddp(): if "RANK" in os.environ and "WORLD_SIZE" in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) local_rank = int(os.environ.get("LOCAL_RANK", 0)) dist.init_process_group(backend="nccl") torch.cuda.set_device(local_rank) return True, rank, local_rank, world_size else: return False, 0, 0, 1 def cleanup_ddp(): if dist.is_initialized(): dist.barrier() dist.destroy_process_group() def is_main_process(rank): return rank == 0 # --------------------------- # Utilities # --------------------------- def exists(x): return x is not None def default(val, d): return val if exists(val) else d def sinusoidal_time_embedding(timesteps, dim): """ timesteps: (B,) int64 return: (B, dim) """ device = timesteps.device half = dim // 2 freqs = torch.exp( torch.arange(half, device=device, dtype=torch.float32) * (-math.log(10000.0) / (half - 1)) ) args = timesteps.float()[:, None] * freqs[None, :] emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) if dim % 2 == 1: emb = F.pad(emb, (0,1)) return emb # --------------------------- # FiLM Conditioning Blocks # --------------------------- class FiLM(nn.Module): """ Feature-wise Linear Modulation cond -> (gamma, beta) -> y = x * (1 + gamma) + beta """ def __init__(self, cond_dim, channels): super().__init__() self.to_gamma_beta = nn.Sequential( nn.SiLU(), nn.Linear(cond_dim, channels * 2) ) def forward(self, x, cond): """ x: (B, C, T, H, W) cond: (B, T, cond_dim) """ B, C, T, H, W = x.shape cond = self.to_gamma_beta(cond) # (B, T, 2C) cond = cond.view(B, T, 2, C).permute(0,2,3,1) # (B,2,C,T) gamma, beta = cond[:,0], cond[:,1] # (B,C,T) gamma = gamma.unsqueeze(-1).unsqueeze(-1) # (B,C,T,1,1) beta = beta.unsqueeze(-1).unsqueeze(-1) # (B,C,T,1,1) return x * (1 + gamma) + beta # --------------------------- # 3D UNet Blocks # --------------------------- class ResBlock3D(nn.Module): def __init__(self, in_ch, out_ch, cond_dim=None): super().__init__() self.in_ch = in_ch self.out_ch = out_ch self.norm1 = nn.GroupNorm(8, in_ch) self.act1 = nn.SiLU() self.conv1 = nn.Conv3d(in_ch, out_ch, 3, padding=1) self.norm2 = nn.GroupNorm(8, out_ch) self.act2 = nn.SiLU() self.conv2 = nn.Conv3d(out_ch, out_ch, 3, padding=1) self.skip = nn.Conv3d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity() # time + cond FiLM self.has_cond = cond_dim is not None self.to_film_t = nn.Sequential(nn.SiLU(), nn.Linear(256, out_ch*2)) # time emb -> (gamma,beta) self.to_film_c = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, out_ch*2)) if self.has_cond else None def forward(self, x, t_emb, c_emb=None): """ x: (B,C,T,H,W) t_emb: (B,256) c_emb: (B,T,cond_dim) or None """ h = self.norm1(x) h = self.act1(h) h = self.conv1(h) # apply FiLM from time embedding (global, broadcast to T,H,W) B, C, T, H, W = h.shape film_t = self.to_film_t(t_emb).view(B, 2, C, 1, 1, 1) # (B,2,C,1,1,1) gamma_t, beta_t = film_t[:,0], film_t[:,1] h = h * (1 + gamma_t) + beta_t # apply FiLM from condition embedding (frame-wise) if self.has_cond and exists(c_emb): film_c = self.to_film_c(c_emb) # (B,T,2C) film_c = film_c.view(B, T, 2, C).permute(0,2,3,1) # (B,2,C,T) gamma_c, beta_c = film_c[:,0], film_c[:,1] # (B,C,T) gamma_c = gamma_c.unsqueeze(-1).unsqueeze(-1) # (B,C,T,1,1) beta_c = beta_c.unsqueeze(-1).unsqueeze(-1) h = h * (1 + gamma_c) + beta_c h = self.norm2(h) h = self.act2(h) h = self.conv2(h) return h + self.skip(x) class Down3D(nn.Module): def __init__(self, ch): super().__init__() # ์๊ฐ์ถ์ ์ ์ง, ๊ณต๊ฐ๋ง downsample (2x) self.op = nn.Conv3d(ch, ch, (1,4,4), stride=(1,2,2), padding=(0,1,1)) def forward(self, x): return self.op(x) class Up3D(nn.Module): def __init__(self, ch): super().__init__() self.op = nn.ConvTranspose3d(ch, ch, (1,4,4), stride=(1,2,2), padding=(0,1,1), output_padding=(0,0,0)) def forward(self, x): return self.op(x) class SpatioTemporalUNet(nn.Module): """ ์์ 3D U-Net: ์๊ฐ์ถ(T=75)์ ์ ์ง, ๊ณต๊ฐ(H,W)๋ง ํผ๋ผ๋ฏธ๋. - ๋ ํผ๋ฐ์ค ์ด๋ฏธ์ง(640x640) -> ์์ CNN์ผ๋ก ๊ธ๋ก๋ฒ ์๋ฒ ๋ฉ์ ๋ฝ์ cond์ ํฉ์นจ - ํ๋ ์๋ณ cond (B,T,512) + ref_emb -> FiLM ์ฃผ์
""" def __init__(self, in_ch=3, base_ch=64, depth=3, cond_dim=512, t_emb_dim=256): super().__init__() self.in_ch = in_ch self.base_ch = base_ch self.depth = depth self.cond_dim = cond_dim self.t_emb_dim = t_emb_dim # ์๊ฐ ์๋ฒ ๋ฉ projector self.time_mlp = nn.Sequential( nn.Linear(t_emb_dim, t_emb_dim*4), nn.SiLU(), nn.Linear(t_emb_dim*4, t_emb_dim), ) # ๋ ํผ๋ฐ์ค ์ด๋ฏธ์ง encoder (์๊ฒ) self.ref_enc = nn.Sequential( nn.Conv2d(22, 32, 7, stride=2, padding=3), # 320x320 nn.SiLU(), nn.Conv2d(32, 64, 5, stride=2, padding=2), # 160x160 nn.SiLU(), nn.Conv2d(64, 128, 5, stride=2, padding=2),# 80x80 nn.SiLU(), nn.AdaptiveAvgPool2d(1), ) self.ref_proj = nn.Linear(128, cond_dim) # input conv (3D) self.in_conv = nn.Conv3d(in_ch, base_ch, 3, padding=1) # encoder chs = [base_ch] self.downs = nn.ModuleList() self.down_res = nn.ModuleList() ch = base_ch for i in range(depth): self.down_res.append(ResBlock3D(ch, ch, cond_dim=cond_dim)) self.downs.append(Down3D(ch)) chs.append(ch) mid_ch = ch # middle self.mid1 = ResBlock3D(mid_ch, mid_ch, cond_dim=cond_dim) self.mid2 = ResBlock3D(mid_ch, mid_ch, cond_dim=cond_dim) # decoder self.ups = nn.ModuleList() self.up_res = nn.ModuleList() for i in range(depth): self.ups.append(Up3D(ch)) enc_ch = chs[-(i+2)] # skip channels self.up_res.append(ResBlock3D(ch + enc_ch, enc_ch, cond_dim=cond_dim)) ch = enc_ch self.out_norm = nn.GroupNorm(8, ch) self.out_act = nn.SiLU() self.out_conv = nn.Conv3d(ch, in_ch, 3, padding=1) self.cond_pro = nn.Linear(4800, cond_dim) def forward(self, x, t, cond_seq, ref_img): """ x: (B,3,T,H,W) noisy video t: (B,) int64 timestep cond_seq: (B,T,512) ref_img: (B,3,H,W) """ B, C, T, H, W = x.shape # time emb t_emb = sinusoidal_time_embedding(t, self.t_emb_dim) t_emb = self.time_mlp(t_emb) # (B,256) # reference image -> global cond r = self.ref_enc(ref_img) # (B,128,1,1) r = r.view(B, 128) r = self.ref_proj(r) # (B,512) # merge global ref cond into frame-wise cond # (B,T,512) + (B,1,512) cond_seq = self.cond_pro(cond_seq) cond = cond_seq + r.unsqueeze(1) # UNet h = self.in_conv(x) skips = [] for res, down in zip(self.down_res, self.downs): h = res(h, t_emb, cond) skips.append(h) h = down(h) h = self.mid1(h, t_emb, cond) h = self.mid2(h, t_emb, cond) for up, res, skip in zip(self.ups, self.up_res, reversed(skips)): h = up(h) # match T,H,W (should match T by design; spatial dims may be off by 1) if h.shape[-1] != skip.shape[-1] or h.shape[-2] != skip.shape[-2]: h = F.interpolate(h, size=skip.shape[-2:], mode='trilinear', align_corners=False) h = torch.cat([h, skip], dim=1) h = res(h, t_emb, cond) h = self.out_norm(h) h = self.out_act(h) return self.out_conv(h) # predict noise ฮต # --------------------------- # Simple DDPM Scheduler # --------------------------- @dataclass class DDPMConfig: num_train_timesteps: int = 1000 beta_start: float = 1e-4 beta_end: float = 0.02 beta_schedule: str = "linear" # 'linear' | 'cosine' class DDPMSchedulerSimple: def __init__(self, cfg: DDPMConfig): self.cfg = cfg if getattr(cfg, 'beta_schedule', 'linear') == 'cosine': # Improved DDPM cosine schedule (Nichol & Dhariwal) T = cfg.num_train_timesteps s = 0.008 steps = torch.arange(T + 1, dtype=torch.float64) f = torch.cos(((steps / T + s) / (1 + s)) * math.pi / 2) ** 2 a_bar = (f / f[0]).clamp(min=1e-9) betas = 1 - (a_bar[1:] / a_bar[:-1]) betas = betas.clamp(1e-8, 0.999) self.betas = betas.float() else: self.betas = torch.linspace(cfg.beta_start, cfg.beta_end, cfg.num_train_timesteps) self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1,0), value=1.0) def to(self, device): self.betas = self.betas.to(device) self.alphas = self.alphas.to(device) self.alphas_cumprod = self.alphas_cumprod.to(device) self.alphas_cumprod_prev = self.alphas_cumprod_prev.to(device) return self def add_noise(self, x0, noise, t): """ q(x_t | x_0) = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise x0, noise: (B,3,T,H,W) t: (B,) """ a_bar = self.alphas_cumprod[t].view(-1,1,1,1,1) return torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise def step(self, model_pred_eps, t, x_t): """ One reverse step p(x_{t-1} | x_t) Predict x0 from eps, then compute mean of posterior. """ beta_t = self.betas[t].view(-1,1,1,1,1) a_t = self.alphas[t].view(-1,1,1,1,1) abar_t = self.alphas_cumprod[t].view(-1,1,1,1,1) abar_prev = self.alphas_cumprod_prev[t].view(-1,1,1,1,1) # estimate x0 x0 = (x_t - torch.sqrt(1 - abar_t) * model_pred_eps) / torch.sqrt(abar_t) # posterior q(x_{t-1} | x_t, x0) mean = ( torch.sqrt(abar_prev) * beta_t / (1 - abar_t) * x0 + torch.sqrt(a_t) * (1 - abar_prev) / (1 - abar_t) * x_t ) # add noise with posterior variance (beta_tilde) except for t==0 # beta_tilde_t = beta_t * (1 - abar_prev) / (1 - abar_t) noise = torch.randn_like(x_t) nonzero_mask = (t > 0).float().view(-1,1,1,1,1) var = beta_t * (1.0 - abar_prev) / (1.0 - abar_t) x_prev = mean + nonzero_mask * torch.sqrt(var) * noise return x_prev # --------------------------- # Toy Dataset (replace with real) # --------------------------- import os import glob import numpy as np from PIL import Image import torch from torch.utils.data import Dataset import torchvision.transforms.functional as TF import os, glob, numpy as np, torch from torch.utils.data import Dataset from PIL import Image import torchvision.transforms.functional as TF class VideoMusStrideDataset(Dataset): """ ๊ฐ ์ํ id=A: - PNG: png_crop/A/*.png ์์ ๊ฐ์ ๋๋ค ์์์ s, stride=2๋ก T๊ฐ ์ ํ - MUS: mus/A.npy (shape: (N, D))์์ ๋์ผ ์ธ๋ฑ์ค s, s+2, ... ๋ก TรD ์ ํ - REF: png_sam_crop/A__0100_seg.npy ๋ฅผ (3,H,W)๋ก ๋ณํํด [-1,1] ๋ฐํ: video: (3, T, H, W) in [-1,1] cond : (T, D) # D๋ npy์ ๋ ๋ฒ์งธ ์ฐจ์ (์ฌ๊ธฐ์ 4800) ref : (3, H, W) in [-1,1] """ def __init__( self, root, frames=75, size=640, stride=2, ref_source="seg", # 'seg' or 'first_frame' or 'sapiens' ids=None, # None์ด๋ฉด mus/*.npy ๊ธฐ์ค ์๋ ์์ง png_dirname="png_crop", mus_dirname="mus", seg_dirname="png_sam_crop", strict_length=False, # True๋ฉด ๊ธธ์ด ๋ถ์กฑ ์ ์ค๋ฅ seed=1234, # ๋๋ค ์์ ์ธ๋ฑ์ค ์ฌํ์ฉ # Sapiens integration (optional) sapiens_config: str | None = None, sapiens_checkpoint: str | None = None, sapiens_device: str = "cuda:0", ): super().__init__() self.root = root self.frames = frames self.size = size self.stride = stride self.ref_source = ref_source self.png_dir = os.path.join(root, png_dirname) self.mus_dir = os.path.join(root, mus_dirname) self.seg_dir = os.path.join(root, seg_dirname) self.strict_length = strict_length self.seed = seed # Sapiens self.sapiens = None if self.ref_source == "sapiens": if sapiens_config is None or sapiens_checkpoint is None: raise ValueError("ref_source='sapiens' requires sapiens_config and sapiens_checkpoint") try: from animator_with_music.segmentation.sapiens_body_seg import SapiensBodySegmenter self.sapiens = SapiensBodySegmenter( config_path=sapiens_config, checkpoint_path=sapiens_checkpoint, device=sapiens_device, ) if not self.sapiens.ready: raise RuntimeError("SapiensBodySegmenter not ready") except Exception as e: raise RuntimeError(f"Failed to initialize SapiensBodySegmenter: {e}") # id ์์ง (mus/*.npy ํ์ผ๋ช
๊ธฐ์ค) if ids is None: paths = sorted(glob.glob(os.path.join(self.mus_dir, "*.npy"))) self.ids = [os.path.splitext(os.path.basename(p))[0] for p in paths] else: self.ids = list(ids) if not self.ids: raise RuntimeError("No ids found under mus/*.npy") def __len__(self): return len(self.ids) # ---------------- index ์ ํ(๊ณตํต) ---------------- def _choose_indices(self, num_pngs, num_mus, rng): """ PNG ๊ฐ์(num_pngs)์ MUS ๊ธธ์ด(num_mus)๋ฅผ ๊ณ ๋ คํด ๊ฐ์ ๋๋ค ์์์ ์์ stride=2๋ก frames๊ฐ ์ธ๋ฑ์ค๋ฅผ ๋ฐํ. """ need = (self.frames - 1) * self.stride + 1 max_len = min(num_pngs, num_mus) if max_len >= need: max_start = max_len - need start = int(rng.integers(0, max_start + 1)) idxs = [start + i * self.stride for i in range(self.frames)] else: # ๊ธธ์ด ๋ถ์กฑ if self.strict_length: raise ValueError(f"Not enough length (png={num_pngs}, mus={num_mus}) for frames={self.frames}, stride={self.stride}") # ๊ฐ๋ฅํ ๋งํผ ๋ง๋ค๊ณ ๋ง์ง๋ง ์ธ๋ฑ์ค ๋ฐ๋ณต์ผ๋ก ํจ๋ start = 0 idxs = [start + i * self.stride for i in range(max(1, (max_len - 1)//self.stride + 1))] idxs = idxs[:self.frames] while len(idxs) < self.frames: idxs.append(idxs[-1]) return idxs # ---------------- ๋ก๋๋ค ---------------- def _load_png_frames(self, A, idxs): frame_dir = os.path.join(self.png_dir, A) pngs = sorted(glob.glob(os.path.join(frame_dir, "*.png"))) if not pngs: raise FileNotFoundError(f"No PNG frames in {frame_dir}") imgs = [] last_valid = len(pngs) - 1 for i in idxs: j = min(i, last_valid) # ๋ฒ์ ์ด๊ณผ์ ๋ง์ง๋ง ํ๋ ์ ์ฌ์ฉ img = Image.open(pngs[j]).convert("RGB") if img.size != (self.size, self.size): img = img.resize((self.size, self.size), Image.BICUBIC) img_t = TF.to_tensor(img) # [0,1] img_t = img_t * 2.0 - 1.0 # [-1,1] imgs.append(img_t) video = torch.stack(imgs, dim=0).permute(1, 0, 2, 3) # (3,T,H,W) return video def _load_mus_cond(self, A, idxs): mus_path = os.path.join(self.mus_dir, f"{A}.npy") if not os.path.isfile(mus_path): raise FileNotFoundError(f"Missing mus file: {mus_path}") arr = np.load(mus_path) # Support (N,D) or (1,N,D) if arr.ndim == 3: _, N, D = arr.shape idxs_adj = [min(i, N - 1) for i in idxs] cond = arr[:, idxs_adj] # (1,T,D) cond_t = torch.from_numpy(cond).float().squeeze(0) # (T,D) elif arr.ndim == 2: N, D = arr.shape idxs_adj = [min(i, N - 1) for i in idxs] cond = arr[idxs_adj] # (T,D) cond_t = torch.from_numpy(cond).float() else: raise ValueError(f"Unexpected mus shape: {arr.shape}") return cond_t, D def _load_ref_seg(self, A): """ seg npy: ๊ฐ 0~22 (0=background) ์ถ๋ ฅ: (1, 22, self.size, self.size) # N=1 ๋ฐฐ์น """ seg_path = os.path.join(self.seg_dir, f"{A}__0100_seg.npy") if not os.path.isfile(seg_path): raise FileNotFoundError(f"Missing seg file: {seg_path}") arr = np.load(seg_path) # (H,W) or (H,W,C) # ๋ง์ฝ ์ฑ๋์ด ์์ผ๋ฉด ์ฒซ ์ฑ๋๋ง ์ฌ์ฉ (๋ผ๋ฒจ ๋งต ๊ฐ์ ) if arr.ndim == 3: arr = arr[..., 0] # PIL๋ก ์ต๊ทผ์ ๋ฆฌ์ฌ์ด์ฆ (๋ ์ด๋ธ ๋ณด์กด) # ๊ฐ ๋ฒ์ 0~22์ด๋ฏ๋ก uint8๋ก ์์ img = Image.fromarray(arr.astype(np.uint8), mode="L") if img.size != (self.size, self.size): img = img.resize((self.size, self.size), Image.NEAREST) arr_resized = np.array(img, dtype=np.int64) # (H,W), ๊ฐ 0~22 # torch ํ
์๋ก ๋ณํ lab = torch.from_numpy(arr_resized).long() # (H,W) # one-hot: num_classes=23์ผ๋ก ๋ง๋ค๊ณ , ์ฑ๋ 0(๋ฐฐ๊ฒฝ) ๋ฒ๋ฆฌ๊ณ 1~22๋ง ์ฌ์ฉ oh23 = F.one_hot(lab, num_classes=23) # (H,W,23) oh_1_22 = oh23[..., 1:] # (H,W,22) # (22,H,W)๋ก permute ํ float32 ref = oh_1_22.permute(2, 0, 1).contiguous().float() # (22, H, W) return ref def _load_ref_seg_sapiens(self, A, idxs): """ Use Sapiens to segment a representative frame and return one-hot [22,H,W]. We take the middle frame index from idxs to align with current clip. Extra Sapiens classes beyond 22 are truncated. """ if self.sapiens is None: raise RuntimeError("Sapiens segmenter is not initialized") frame_dir = os.path.join(self.png_dir, A) pngs = sorted(glob.glob(os.path.join(frame_dir, "*.png"))) if not pngs: raise FileNotFoundError(f"No PNG frames in {frame_dir}") mid = idxs[len(idxs)//2] j = min(mid, len(pngs)-1) from PIL import Image img = Image.open(pngs[j]).convert("RGB") np_rgb = np.array(img) seg = self.sapiens.predict_mask(np_rgb) # (H,W) int # resize by nearest to target size if img.size != (self.size, self.size): imgL = Image.fromarray(seg.astype(np.int32), mode="I") imgL = imgL.resize((self.size, self.size), Image.NEAREST) seg = np.array(imgL, dtype=np.int64) lab = torch.from_numpy(seg).long() # Make one-hot with 23 classes and drop background channel 0 -> 22 channels expected by model oh = F.one_hot(lab, num_classes=max(23, int(lab.max().item())+1)) # (H,W,C) if oh.shape[-1] < 23: pad = torch.zeros((oh.shape[0], oh.shape[1], 23 - oh.shape[-1]), dtype=oh.dtype) oh = torch.cat([oh, pad], dim=-1) oh_1_22 = oh[..., 1:23] # (H,W,22) ref = oh_1_22.permute(2, 0, 1).contiguous().float() # (22,H,W) return ref def __getitem__(self, idx): A = self.ids[idx] # ์ฌํ ๊ฐ๋ฅํ ๋๋ค ์์ ์ธ๋ฑ์ค (์ํ id๋ณ ๊ณ ์ seed) rng = np.random.default_rng(self.seed + idx) # ๊ธธ์ด ํ์ธ์ฉ num_pngs = len(sorted(glob.glob(os.path.join(self.png_dir, A, "*.png")))) if num_pngs == 0: raise FileNotFoundError(f"No PNG frames in {os.path.join(self.png_dir, A)}") mus_path = os.path.join(self.mus_dir, f"{A}.npy") arr = np.load(mus_path, mmap_mode="r") num_mus = arr.shape[1] # N # ๊ณตํต ์ธ๋ฑ์ค ์ ํ idxs = self._choose_indices(num_pngs, num_mus, rng) # ๋ก๋ video = self._load_png_frames(A, idxs) # (3,T,H,W) cond, cond_dim = self._load_mus_cond(A, idxs) # (T,D) (D=4800) if self.ref_source == "seg": ref = self._load_ref_seg(A) # (22,H,W) elif self.ref_source == "sapiens": ref = self._load_ref_seg_sapiens(A, idxs) # (22,H,W) elif self.ref_source == "first_frame": ref = video[:, 0] # (3,H,W) else: raise ValueError(f"Unknown ref_source: {self.ref_source}") # ๋ชจ๋ธ cond_dim ํ์ธ/๋๊ธฐํ๋ ์ธ๋ถ์์ ์ฒ๋ฆฌ(์๋ ์ฐธ๊ณ ) return video, cond, ref # --------------------------- # Training / Inference # --------------------------- def train_one_epoch(model, sched, loader, optimizer, device, amp=True, p2_gamma=1.0, debug=True, cond_dropout=0.1): model.train() mse = nn.MSELoss(reduction="none") # <- per-sample ๊ฐ์ค ์ํด none scaler = torch.cuda.amp.GradScaler(enabled=amp) total_loss = 0.0 for batch_idx, (video, cond, ref) in enumerate(loader): video = video.to(device) # (B,3,T,H,W), in [-1,1] cond = cond.to(device) # (B,T,512) ref = ref.to(device) # (B,3,H,W) B = video.size(0) t = torch.randint(0, sched.cfg.num_train_timesteps, (B,), device=device, dtype=torch.long) noise = torch.randn_like(video) x_t = sched.add_noise(video, noise, t) optimizer.zero_grad(set_to_none=True) # Classifier-free conditioning dropout if cond_dropout > 0.0: drop_mask = (torch.rand((B,1), device=device) < cond_dropout).float() # zero-out cond/ref for a subset of batch cond_in = cond * (1.0 - drop_mask).unsqueeze(-1) ref_in = ref * (1.0 - drop_mask.view(B,1,1,1)) else: cond_in, ref_in = cond, ref with torch.cuda.amp.autocast(enabled=amp): eps_pred = model(x_t, t, cond_in, ref_in) # predict ฮต per_elem = mse(eps_pred, noise) # (B,3,T,H,W) per_sample = per_elem.mean(dim=(1,2,3,4)) # (B,) # ----- p2 weighting ----- abar = sched.alphas_cumprod[t] # (B,) snr = abar / (1.0 - abar) # (B,) weight = torch.pow(snr + 1.0, -p2_gamma) # (B,) # (์ ํ) ๋๋ฌด ์์/ํฐ ๊ฐ์ค์น ํด๋ฆฌํ # weight = weight.clamp_(min=1e-3) loss = (per_sample * weight).mean() if debug and batch_idx == 0: # ๊ธฐ๋ณธ ํต๊ณ ๋ฐ ํ
์ ํฌ๊ธฐ ํ์ธ (์ฒซ ๋ฐฐ์น๋ง) with torch.no_grad(): try: v_min, v_max = float(video.min().item()), float(video.max().item()) x_min, x_max = float(x_t.min().item()), float(x_t.max().item()) e_min, e_max = float(eps_pred.min().item()), float(eps_pred.max().item()) except Exception: v_min = v_max = x_min = x_max = e_min = e_max = float('nan') print("[DBG] batch=0 video", tuple(video.shape), f"range=[{v_min:.3f},{v_max:.3f}]") print("[DBG] batch=0 cond ", tuple(cond.shape)) print("[DBG] batch=0 ref ", tuple(ref.shape)) print("[DBG] batch=0 t ", tuple(t.shape), f"min={int(t.min().item())} max={int(t.max().item())}") print("[DBG] batch=0 x_t ", tuple(x_t.shape), f"range=[{x_min:.3f},{x_max:.3f}]") print("[DBG] batch=0 eps ", tuple(eps_pred.shape), f"range=[{e_min:.3f},{e_max:.3f}]") print("[DBG] batch=0 loss ", float(per_sample.mean().item())) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() total_loss += loss.item() * B return total_loss / len(loader.dataset) from tqdm import tqdm class EMA: def __init__(self, model, decay=0.999): self.decay = decay self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items() if v.dtype.is_floating_point} @torch.no_grad() def update(self, model): for k, v in model.state_dict().items(): if v.dtype.is_floating_point and k in self.shadow: self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay) @torch.no_grad() def copy_to(self, model): sd = model.state_dict() for k, v in self.shadow.items(): if k in sd: sd[k].copy_(v) model.load_state_dict(sd, strict=False) @torch.no_grad() def sample_video( model, sched, cond_seq, ref_img, device, size=640, frames=75, steps=None, show_progress: bool = True, desc: str = "Sampling", guidance_scale: float = 1.0, ): """ cond_seq: (1,T,D) ref_img : (1,C,H,W) # in [-1,1] return : (T,H,W,C) uint8 """ model.eval() steps = sched.cfg.num_train_timesteps if steps is None else steps x = torch.randn(1, 3, frames, size, size, device=device) iter_steps = range(steps - 1, -1, -1) # steps-1 ... 0 if show_progress: iter_steps = tqdm(iter_steps, total=steps, desc=desc, dynamic_ncols=True) for i in iter_steps: t = torch.full((1,), i, device=device, dtype=torch.long) if guidance_scale is None or guidance_scale <= 1.0: eps = model(x, t, cond_seq, ref_img) else: # Classifier-free guidance at inference: run with null cond/ref eps_uncond = model(x, t, torch.zeros_like(cond_seq), torch.zeros_like(ref_img)) eps_cond = model(x, t, cond_seq, ref_img) eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond) x = sched.step(eps, t, x) # ์ํ๋ฉด ์งํ๋ฐ์ ์ถ๊ฐ ์ ๋ณด ํ์ (์: ํ์ฌ t) if show_progress and hasattr(iter_steps, "set_postfix"): iter_steps.set_postfix(t=int(i)) x = x.clamp(-1, 1) x = (x * 127.5 + 127.5).round().byte() # [0,255] x = x[0].permute(1, 0, 2, 3).cpu().numpy() # (T,3,H,W) x = rearrange(x, "t c h w -> t h w c") return x def save_mp4(frames_uint8, path, fps=15): os.makedirs(os.path.dirname(path), exist_ok=True) iio.imwrite(path, frames_uint8, fps=fps, codec="h264") def load_model_only(path, model, map_location=None): if not os.path.isfile(path): raise FileNotFoundError(f"Checkpoint not found: {path}") state = torch.load(path, map_location=map_location or "cpu") unwrap_ddp(model).load_state_dict(state, strict=True) print(f"[CKPT] loaded model weights <- {path}") # --------------------------- # Main # --------------------------- def main(): parser = argparse.ArgumentParser() parser.add_argument("--epochs", type=int, default=1000) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--frames", type=int, default=75) parser.add_argument("--size", type=int, default=640) parser.add_argument("--base_ch", type=int, default=8) # ๋ฉ๋ชจ๋ฆฌ ์ค์ด๊ธฐ ์ํด ๊ธฐ๋ณธ 48 parser.add_argument("--depth", type=int, default=3) parser.add_argument("--save", type=str, default="runs/sample2.mp4") parser.add_argument("--fps", type=int, default=15) parser.add_argument("--no_amp", action="store_true") parser.add_argument("--ckpt_path", type=str, default="runs/ckpt_latest.pt") parser.add_argument("--use_ema", action="store_true") parser.add_argument("--cond_dropout", type=float, default=0.1) parser.add_argument("--guidance_scale", type=float, default=1.0) parser.add_argument("--beta_schedule", type=str, default="cosine", choices=["linear","cosine"]) # cosine by default parser.add_argument("--debug", action="store_true", help="I/O์ ํ
์ ํฌ๊ธฐ/๋ฒ์๋ฅผ ๋ก๊น
") # Sapiens-related CLI (optional) # Use with ds: set ref_source='sapiens' and provide config/checkpoint parser.add_argument("--ref_source", type=str, default="seg", choices=["seg","first_frame","sapiens"], help="Reference source type") parser.add_argument("--sapiens_seg_config", type=str, default=None) parser.add_argument("--sapiens_seg_checkpoint", type=str, default=None) parser.add_argument("--sapiens_device", type=str, default="cuda:0") args = parser.parse_args() is_ddp, rank, local_rank, world_size = setup_ddp() # ์์ ์๋ดํ ํจ์ # DDP ์ด๊ธฐํ ์ดํ device = f"cuda:{local_rank}" if is_ddp else ("cuda" if torch.cuda.is_available() else "cpu") model = SpatioTemporalUNet( in_ch=3, base_ch=args.base_ch, depth=args.depth, cond_dim=4800, t_emb_dim=256, ).to(device) if is_ddp: model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) optim = torch.optim.AdamW(model.parameters(), lr=args.lr) ema = EMA(unwrap_ddp(model)) if args.use_ema else None ds = VideoMusStrideDataset( root="/root/AIST_processed", frames=args.frames, size=args.size, stride=2, ref_source=args.ref_source, # 'seg' | 'first_frame' | 'sapiens' ids=["gBR_sBM_cAll_d04_mBR1_ch03", "gBR_sBM_cAll_d04_mBR1_ch10", "gBR_sBM_cAll_d04_mBR2_ch03", "gBR_sBM_cAll_d04_mBR2_ch07" ], # ํน์ id๋ง ํ์ต sapiens_config=args.sapiens_seg_config, sapiens_checkpoint=args.sapiens_seg_checkpoint, sapiens_device=args.sapiens_device, ) if is_ddp: sampler = DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=True) shuffle = False else: sampler = None shuffle = True dl = DataLoader( ds, batch_size=args.batch_size, shuffle=shuffle, sampler=sampler, num_workers=16 if is_ddp else 8, pin_memory=True, persistent_workers=True if (is_ddp) else False, prefetch_factor=1, ) sched = DDPMSchedulerSimple(DDPMConfig(beta_schedule=args.beta_schedule)).to(device) if args.debug: print("[DBG] device:", device) tot_params = sum(p.numel() for p in unwrap_ddp(model).parameters()) train_params = sum(p.numel() for p in unwrap_ddp(model).parameters() if p.requires_grad) print(f"[DBG] model params: total={tot_params:,} trainable={train_params:,}") print(f"[DBG] dataset size: {len(ds)} ids={getattr(ds,'ids', None) if hasattr(ds,'ids') else 'n/a'}") print(f"[DBG] frames={args.frames} size={args.size} base_ch={args.base_ch} depth={args.depth}") print(f"[DBG] dataloader: batch_size={args.batch_size} workers={16 if is_ddp else 8} shuffle={shuffle}") try: batch = next(iter(dl)) video_b, cond_b, ref_b = batch print("[DBG] sample batch video:", tuple(video_b.shape), f"range=[{float(video_b.min()):.3f},{float(video_b.max()):.3f}]") print("[DBG] sample batch cond :", tuple(cond_b.shape)) print("[DBG] sample batch ref :", tuple(ref_b.shape), f"range=[{float(ref_b.min()):.3f},{float(ref_b.max()):.3f}]") except Exception as e: print("[DBG] failed to draw sample batch:", repr(e)) # train for e in tqdm(range(args.epochs), desc="Epochs", unit="epoch"): if is_ddp and sampler is not None: sampler.set_epoch(e) loss = train_one_epoch( model, sched, dl, optim, device, # e.g., f"cuda:{local_rank}" or "cuda" amp=(not args.no_amp), # ํผํฉ์ ๋ฐ ์ต์
p2_gamma=1.0, # p2 ๊ฐ์ค(์ ํ) debug=(args.debug and ((not is_ddp) or rank == 0)), cond_dropout=args.cond_dropout, ) if ema is not None: ema.update(unwrap_ddp(model)) if not is_ddp or rank == 0: print(f"[Epoch {e+1}] loss={loss:.4f}") # โ
์ฌ๊ธฐ์ ์ ์ฅ (rank0๋ง) if not is_ddp or rank == 0: save_model_only(args.ckpt_path, model) print(f"[CKPT] saved model weights -> {args.ckpt_path}") # Build a conditioning pair from dataset (first id) try: _, cond_s, ref_s = ds[0] except Exception as e: raise RuntimeError(f"Failed to fetch a sample for inference: {e}") cond = cond_s.unsqueeze(0).to(device) # (1,T,D) ref = ref_s.unsqueeze(0).to(device) # (1,C,H,W) # Choose model for inference (EMA if available) model_infer = unwrap_ddp(model) if ema is not None: model_ema = SpatioTemporalUNet( in_ch=3, base_ch=args.base_ch, depth=args.depth, cond_dim=4800, t_emb_dim=256, ).to(device) ema.copy_to(model_ema) model_infer = model_ema # Sample frames_uint8 = sample_video( model_infer, sched, cond, ref, device, size=args.size, frames=args.frames, steps=None, show_progress=True, desc="DDPM Inference", guidance_scale=args.guidance_scale ) # Save save_mp4(frames_uint8, args.save, fps=args.fps) print(f"[OK] saved video -> {args.save}") return if __name__ == "__main__": main() # CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py --epochs 100 --batch_size 4
\ No newline at end of file
diff --git a/animation/modules/unet.py b/animation/modules/unet.py
index 52a8725..56a5905 100644
--- a/animation/modules/unet.py
+++ b/animation/modules/unet.py
@@ -329,7 +329,9 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor)
- def _set_gradient_checkpointing(self, module, value=False):
+ def _set_gradient_checkpointing(self, module, value=False, enable=None):
+ if enable is not None:
+ module.gradient_checkpointing = True
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
diff --git a/animation/modules/unet_new.py b/animation/modules/unet_new.py
new file mode 100644
index 0000000..7a3dd00
--- /dev/null
+++ b/animation/modules/unet_new.py
@@ -0,0 +1,589 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput, logging
+
+from animation.modules.unet_3d_blocks import get_down_block, UNetMidBlockSpatioTemporal, get_up_block
+# from diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNetSpatioTemporalConditionOutput(BaseOutput):
+ """
+ The output of [`UNetSpatioTemporalConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state,
+ and a timestep and returns a sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ addition_time_embed_dim: (`int`, defaults to 256):
+ Dimension to to encode the additional time ids.
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
+ The dimension of the projection of encoded `added_time_ids`.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+ The number of attention heads.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. " \
+ f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. " \
+ f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. " \
+ f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. " \
+ f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. " \
+ f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=1,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-5,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockSpatioTemporal(
+ block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=1e-5,
+ resolution_idx=i,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+ self.conv_act = nn.SiLU()
+
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0],
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` " \
+ f"when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ pose_latents: torch.Tensor = None,
+ image_only_indicator: bool = False,
+ return_dict: bool = True,
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+ r"""
+ The [`UNetSpatioTemporalConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.FloatTensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ pose_latents: (`torch.FloatTensor`):
+ The additional latents for pose sequences.
+ image_only_indicator (`bool`, *optional*, defaults to `False`):
+ Whether or not training with all images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`]
+ instead of a plain tuple.
+ Returns:
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True,
+ an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ pose_latents: torch.Tensor = None,
+ image_only_indicator: bool = False,
+ return_dict: bool = True,
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+ r"""
+ The [`UNetSpatioTemporalConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.FloatTensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ pose_latents: (`torch.FloatTensor`):
+ The additional latents for pose sequences.
+ image_only_indicator (`bool`, *optional*, defaults to `False`):
+ Whether or not training with all images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`]
+ instead of a plain tuple.
+ Returns:
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+ If `return_dict` is True,
+ an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+ if pose_latents is not None:
+ sample = sample + pose_latents
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+ if pose_latents is not None:
+ sample = sample + pose_latents
+
+ image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \
+ if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 7. Reshape back to original shape
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+ if not return_dict:
+ return (sample,)
+
+ return UNetSpatioTemporalConditionOutput(sample=sample)
diff --git a/animation/pipelines/inference_pipeline_animation.py b/animation/pipelines/inference_pipeline_animation.py
index 7d6c412..1d852d4 100644
--- a/animation/pipelines/inference_pipeline_animation.py
+++ b/animation/pipelines/inference_pipeline_animation.py
@@ -86,7 +86,7 @@ def __init__(
unet,
scheduler,
feature_extractor,
- pose_net,
+ music_encoder,
face_encoder,
):
super().__init__()
@@ -97,7 +97,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
feature_extractor=feature_extractor,
- pose_net=pose_net,
+ music_encoder=music_encoder,
face_encoder=face_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -363,7 +363,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
- image_pose: Union[torch.FloatTensor],
+ music: torch.FloatTensor,
height: int = 576,
width: int = 1024,
num_frames: Optional[int] = None,
@@ -594,19 +594,19 @@ def __call__(
if indices[-1][-1] < num_frames - 1:
indices.append([0, *range(num_frames - tile_size + 1, num_frames)])
- pose_pil_image_list = []
- for pose in image_pose:
- pose = torch.from_numpy(np.array(pose)).float()
- pose = pose / 127.5 - 1
- pose_pil_image_list.append(pose)
- pose_pil_image_list = torch.stack(pose_pil_image_list, dim=0)
- pose_pil_image_list = rearrange(pose_pil_image_list, "f h w c -> f c h w")
+ # pose_pil_image_list = []
+ # for pose in image_pose:
+ # pose = torch.from_numpy(np.array(pose)).float()
+ # pose = pose / 127.5 - 1
+ # pose_pil_image_list.append(pose)
+ # pose_pil_image_list = torch.stack(pose_pil_image_list, dim=0)
+ # pose_pil_image_list = rearrange(pose_pil_image_list, "f h w c -> f c h w")
# print(indices) # [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]
# print(pose_pil_image_list.size()) # [16, 3, 512, 512]
- self.pose_net.to(device)
+ self.music_encoder.to(device)
self.unet.to(device)
with torch.cuda.device(device):
@@ -628,7 +628,8 @@ def __call__(
weight = torch.minimum(weight, 2 - weight)
for idx in indices:
# classification-free inference
- pose_latents = self.pose_net(pose_pil_image_list[idx].to(device=device, dtype=latent_model_input.dtype))
+ x_music = music[idx, :].unsqueeze(0).to(device, torch.float16)
+ music_latents = self.music_encoder(x_music)
_noise_pred = self.unet(
latent_model_input[:1, idx],
t,
@@ -646,7 +647,7 @@ def __call__(
t,
encoder_hidden_states=image_embeddings[1:],
added_time_ids=added_time_ids[1:],
- pose_latents=pose_latents,
+ pose_latents=music_latents,
image_only_indicator=image_only_indicator,
return_dict=False,
)[0]
diff --git a/assets/figures/case-17.gif b/assets/figures/case-17.gif
deleted file mode 100644
index e5a632f..0000000
Binary files a/assets/figures/case-17.gif and /dev/null differ
diff --git a/assets/figures/case-18.gif b/assets/figures/case-18.gif
deleted file mode 100644
index 69777c0..0000000
Binary files a/assets/figures/case-18.gif and /dev/null differ
diff --git a/assets/figures/case-24.gif b/assets/figures/case-24.gif
deleted file mode 100644
index e27342c..0000000
Binary files a/assets/figures/case-24.gif and /dev/null differ
diff --git a/assets/figures/case-35.gif b/assets/figures/case-35.gif
deleted file mode 100644
index bdab604..0000000
Binary files a/assets/figures/case-35.gif and /dev/null differ
diff --git a/assets/figures/case-42.gif b/assets/figures/case-42.gif
deleted file mode 100644
index 76db98f..0000000
Binary files a/assets/figures/case-42.gif and /dev/null differ
diff --git a/assets/figures/case-45.gif b/assets/figures/case-45.gif
deleted file mode 100644
index 66db02f..0000000
Binary files a/assets/figures/case-45.gif and /dev/null differ
diff --git a/assets/figures/case-46.gif b/assets/figures/case-46.gif
deleted file mode 100644
index faa7d55..0000000
Binary files a/assets/figures/case-46.gif and /dev/null differ
diff --git a/assets/figures/case-47.gif b/assets/figures/case-47.gif
deleted file mode 100644
index 8f55bdd..0000000
Binary files a/assets/figures/case-47.gif and /dev/null differ
diff --git a/assets/figures/case-5.gif b/assets/figures/case-5.gif
deleted file mode 100644
index 060f872..0000000
Binary files a/assets/figures/case-5.gif and /dev/null differ
diff --git a/assets/figures/case-61.gif b/assets/figures/case-61.gif
deleted file mode 100644
index 11f3d2b..0000000
Binary files a/assets/figures/case-61.gif and /dev/null differ
diff --git a/assets/figures/framework.jpg b/assets/figures/framework.jpg
deleted file mode 100644
index ed44806..0000000
Binary files a/assets/figures/framework.jpg and /dev/null differ
diff --git a/checkpoints b/checkpoints
new file mode 160000
index 0000000..172658c
--- /dev/null
+++ b/checkpoints
@@ -0,0 +1 @@
+Subproject commit 172658c01f49905d09f33c8872cc701e3f6d9ec5
diff --git a/command_basic_infer.sh b/command_basic_infer.sh
index 03bf6f2..6c92e3b 100644
--- a/command_basic_infer.sh
+++ b/command_basic_infer.sh
@@ -1,15 +1,16 @@
CUDA_VISIBLE_DEVICES=0 python inference_basic.py \
- --pretrained_model_name_or_path="path/checkpoints/SVD/stable-video-diffusion-img2vid-xt" \
- --output_dir="path/basic_infer" \
- --validation_control_folder="path/inference/case-1/poses" \
- --validation_image="path/inference/case-1/reference.png" \
- --width=576 \
- --height=1024 \
+ --pretrained_model_name_or_path="/root/StableAnimator_Music/checkpoints/stable-video-diffusion-img2vid-xt" \
+ --output_dir="/root/basic_infer" \
+ --validation_music="/root/dataset/rec/gLH_sBM_cAll_d16_mLH0_ch09.h5" \
+ --validation_image="/root/dataset/rec/gKR_sBM_cAll_d28_mKR1_ch07/images/square_cropped/0444.png" \
+ --width=512 \
+ --height=512 \
+ --length=100 \
--guidance_scale=3.0 \
--num_inference_steps=25 \
- --posenet_model_name_or_path="path/checkpoints/Animation/pose_net.pth" \
- --face_encoder_model_name_or_path="path/checkpoints/Animation/face_encoder.pth" \
- --unet_model_name_or_path="path/checkpoints/Animation/unet.pth" \
+ --music_encoder_model_name_or_path="/root/StableAnimator_Music/checkpoints/Animation/checkpoint-70000/music_encoder-70000.pth" \
+ --face_encoder_model_name_or_path="/root/StableAnimator_Music/checkpoints/Animation/checkpoint-70000/face_encoder-70000.pth" \
+ --unet_model_name_or_path="/root/StableAnimator_Music/checkpoints/Animation/checkpoint-70000/unet-70000.pth" \
--tile_size=16 \
--overlap=4 \
--noise_aug_strength=0.02 \
diff --git a/command_train.sh b/command_train.sh
index 42eff7c..fb5ac7d 100644
--- a/command_train.sh
+++ b/command_train.sh
@@ -1,22 +1,23 @@
-CUDA_VISIBLE_DEVICES=3,2,1,0 accelerate launch train.py \
- --pretrained_model_name_or_path="path/checkpoints/SVD/stable-video-diffusion-img2vid-xt" \
- --output_dir="path/checkpoints/Animation" \
- --data_root_path="path/animation_data" \
- --rec_data_path="path/animation_data/video_rec_path.txt" \
- --vec_data_path="path/animation_data/video_vec_path.txt" \
- --validation_image_folder="path/validation/ground_truth" \
- --validation_control_folder="path/validation/poses" \
- --validation_image="path/validation/reference.png" \
- --num_workers=8 \
- --lr_warmup_steps=500 \
+CUDA_VISIBLE_DEVICES=0,1 accelerate launch /root/StableAnimator_Music/train_single.py \
+ --pretrained_model_name_or_path="/root/StableAnimator_Music/checkpoints/stable-video-diffusion-img2vid-xt" \
+ --output_dir="./checkpoints/Animation" \
+ --data_root_path="/root/aist_hdf5/rec" \
+ --data_path="/root/aist_hdf5/full_list.txt" \
+ --dataset_width=512 \
+ --dataset_height=512 \
+ --validation_image_folder="./validation/ground_truth" \
+ --validation_control_folder="./validation/poses" \
+ --validation_image="./validation/reference.png" \
+ --num_workers=2 \
+ --lr_warmup_steps=10 \
--sample_n_frames=16 \
--learning_rate=1e-5 \
--per_gpu_batch_size=1 \
- --num_train_epochs=6000 \
+ --num_train_epochs=10000 \
--mixed_precision="fp16" \
--gradient_accumulation_steps=1 \
--checkpointing_steps=2000 \
--validation_steps=500 \
- --gradient_checkpointing \
--checkpoints_total_limit=5000 \
- --resume_from_checkpoint="latest"
\ No newline at end of file
+ --resume_from_checkpoint="latest" \
+ --max_train_steps=80000
\ No newline at end of file
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..6acf5c7
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,531 @@
+#!/usr/bin/env python3
+# evaluate_fluency.py
+
+"""
+Temporal Fluency Metric for Video Generation Evaluation
+
+์ด ์คํฌ๋ฆฝํธ๋ ์ธ๊ทธ๋ฉํ
์ด์
๊ธฐ๋ฐ ์๊ฐ์ ์ ์ฐฝ์ฑ(Temporal Fluency)์ ์ธก์ ํฉ๋๋ค.
+
+์์:
+ c_t = |B_t โ B_{t+1}| / |B_t โช B_{t+1}| (ํ๋ ์ ๋ณํ ๋น์จ)
+ ฮผ_V = mean(c_t) (ํ๊ท ๋ณํ๋)
+ ฯยฒ_V = var(c_t) (๋ณํ๋ ๋ถ์ฐ)
+ F_V = 1 / (1 + ฮผ_V + ฯยฒ_V) (์ ์ฐฝ์ฑ ์ ์)
+
+์ ์ ํด์:
+ - 0.90~1.00: ์ต์ (์ค์ ๋น๋์ค ์์ค)
+ - 0.80~0.90: ์ฐ์ (์์ฐ์ค๋ฌ์)
+ - 0.70~0.80: ์ํธ (์ฝ๊ฐ์ ๋๊น)
+ - 0.60~0.70: ๋ณดํต (๋์ ๋๋ ๋๊น)
+ - < 0.60: ๋ถ๋ (์ฌํ ๋๊น)
+
+Author: [Your Name]
+Date: 2025-01-28
+"""
+
+import numpy as np
+import os
+import glob
+from pathlib import Path
+import argparse
+from tqdm import tqdm
+import json
+import pandas as pd
+from typing import List, Dict, Tuple, Optional
+import warnings
+
+warnings.filterwarnings('ignore')
+
+
+class FluencyEvaluator:
+ """๋น๋์ค ์ ์ฐฝ์ฑ ํ๊ฐ ํด๋์ค"""
+
+ def __init__(self, fps: int = 30, window_seconds: int = 5):
+ """
+ Args:
+ fps: ํ๋ ์๋ ์ดํธ (frames per second)
+ window_seconds: ํ๊ฐ ์๋์ฐ ํฌ๊ธฐ (์ด)
+ """
+ self.fps = fps
+ self.window_seconds = window_seconds
+ self.window_frames = fps * window_seconds
+
+ @staticmethod
+ def load_segmentation(npy_path: str) -> Optional[np.ndarray]:
+ """
+ ์ธ๊ทธ๋ฉํ
์ด์
.npy ํ์ผ ๋ก๋
+
+ Args:
+ npy_path: .npy ํ์ผ ๊ฒฝ๋ก
+
+ Returns:
+ body_mask: ์ ์ฒด ์์ญ ์ด์ง ๋ง์คํฌ (H, W) ๋๋ None
+ """
+ try:
+ seg = np.load(npy_path)
+ # ๋ฐฐ๊ฒฝ(0) ์ ์ธ, ์ ์ฒด ๋ถ์๋ง (>0)
+ body_mask = seg > 0
+ return body_mask
+ except Exception as e:
+ print(f"Error loading {npy_path}: {e}")
+ return None
+
+ @staticmethod
+ def compute_frame_change(mask1: np.ndarray, mask2: np.ndarray) -> Optional[float]:
+ """
+ ๋ ํ๋ ์ ๊ฐ ์ ์ฒด ํฝ์
๋ณํ ๋น์จ ๊ณ์ฐ
+
+ ์์: c_t = |B_t โ B_{t+1}| / |B_t โช B_{t+1}|
+
+ Args:
+ mask1: ํ๋ ์ t์ ์ ์ฒด ๋ง์คํฌ
+ mask2: ํ๋ ์ t+1์ ์ ์ฒด ๋ง์คํฌ
+
+ Returns:
+ change_ratio: ๋ณํ ๋น์จ [0, 1]
+ """
+ if mask1 is None or mask2 is None:
+ return None
+
+ # XOR ์ฐ์ฐ: ๋ณํ๋ ํฝ์
(๋์นญ ์ฐจ์งํฉ)
+ changed_pixels = np.logical_xor(mask1, mask2)
+
+ # Union: ์ ์ฒด ์ ์ฒด ํฝ์
+ total_body_pixels = np.logical_or(mask1, mask2).sum()
+
+ if total_body_pixels == 0:
+ return 0.0
+
+ change_ratio = changed_pixels.sum() / total_body_pixels
+ return float(change_ratio)
+
+ def compute_temporal_fluency(
+ self,
+ seg_files: List[str]
+ ) -> Tuple[List[float], List[Dict], float, float, float]:
+ """
+ ์๊ฐ์ ์ ์ฐฝ์ฑ ๊ณ์ฐ
+
+ Args:
+ seg_files: ์ ๋ ฌ๋ ์ธ๊ทธ๋ฉํ
์ด์
ํ์ผ ๊ฒฝ๋ก ๋ฆฌ์คํธ
+
+ Returns:
+ frame_changes: ํ๋ ์๋ณ ๋ณํ๋ [c_1, c_2, ..., c_{N-1}]
+ window_fluency: ์๋์ฐ๋ณ ์ ์ฐฝ์ฑ ์ ๋ณด
+ overall_fluency: ์ ์ฒด ์ ์ฐฝ์ฑ ์ ์ F_V
+ avg_change: ํ๊ท ๋ณํ๋ ฮผ_V
+ variance: ๋ณํ๋ ๋ถ์ฐ ฯยฒ_V
+ """
+ # 1. ํ๋ ์๋ณ ๋ณํ๋ ๊ณ์ฐ
+ frame_changes = []
+ prev_mask = None
+
+ print("ํ๋ ์๋ณ ๋ณํ๋ ๊ณ์ฐ ์ค...")
+ for seg_file in tqdm(seg_files, desc="Processing frames"):
+ curr_mask = self.load_segmentation(seg_file)
+
+ if prev_mask is not None and curr_mask is not None:
+ change = self.compute_frame_change(prev_mask, curr_mask)
+ if change is not None:
+ frame_changes.append(change)
+
+ prev_mask = curr_mask
+
+ if len(frame_changes) == 0:
+ print("โ ๏ธ ์ ํจํ ํ๋ ์ ๋ณํ๋ฅผ ๊ณ์ฐํ ์ ์์ต๋๋ค.")
+ return [], [], 0.0, 0.0, 0.0
+
+ # 2. ์ ์ฒด ํต๊ณ ๊ณ์ฐ
+ avg_change = float(np.mean(frame_changes))
+ variance = float(np.var(frame_changes))
+ overall_fluency = 1.0 / (1.0 + avg_change + variance)
+
+ # 3. ์๋์ฐ๋ณ ์ ์ฐฝ์ฑ ๊ณ์ฐ
+ window_fluency = []
+
+ print(f"\n{self.window_seconds}์ด ์๋์ฐ๋ณ ์ ์ฐฝ์ฑ ๊ณ์ฐ ์ค...")
+ for i in range(0, len(frame_changes), self.window_frames):
+ window = frame_changes[i:i+self.window_frames]
+
+ if len(window) > 0:
+ w_avg = float(np.mean(window))
+ w_var = float(np.var(window))
+ w_fluency = 1.0 / (1.0 + w_avg + w_var)
+
+ window_fluency.append({
+ 'window_id': len(window_fluency) + 1,
+ 'start_frame': i,
+ 'end_frame': min(i + self.window_frames, len(frame_changes)),
+ 'num_frames': len(window),
+ 'avg_change': w_avg,
+ 'variance': w_var,
+ 'fluency_score': w_fluency
+ })
+
+ return frame_changes, window_fluency, overall_fluency, avg_change, variance
+
+ def evaluate_folder(
+ self,
+ seg_folder: str,
+ output_dir: Optional[str] = None,
+ verbose: bool = True
+ ) -> Optional[Dict]:
+ """
+ ํด๋ ๋ด ๋ชจ๋ ์ธ๊ทธ๋ฉํ
์ด์
ํ๊ฐ
+
+ Args:
+ seg_folder: ์ธ๊ทธ๋ฉํ
์ด์
ํด๋ ๊ฒฝ๋ก
+ output_dir: ๊ฒฐ๊ณผ ์ ์ฅ ๋๋ ํ ๋ฆฌ (None์ด๋ฉด ์ ์ฅ ์ ํจ)
+ verbose: ์์ธ ์ถ๋ ฅ ์ฌ๋ถ
+
+ Returns:
+ results: ํ๊ฐ ๊ฒฐ๊ณผ ๋์
๋๋ฆฌ
+ """
+ # .npy ํ์ผ ์ฐพ๊ธฐ ๋ฐ ์ ๋ ฌ
+ seg_files = sorted(glob.glob(os.path.join(seg_folder, '*_seg.npy')))
+
+ if len(seg_files) == 0:
+ print(f"โ {seg_folder}์์ ์ธ๊ทธ๋ฉํ
์ด์
ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค.")
+ return None
+
+ if verbose:
+ print(f"\n{'='*70}")
+ print(f"๐ ํด๋: {seg_folder}")
+ print(f"๐ ์ด ํ๋ ์ ์: {len(seg_files)}")
+ print(f"๐ฌ FPS: {self.fps}")
+ print(f"โฑ๏ธ ์๋์ฐ: {self.window_seconds}์ด")
+ print('='*70)
+
+ # ์ ์ฐฝ์ฑ ๊ณ์ฐ
+ frame_changes, window_fluency, overall_fluency, avg_change, variance = \
+ self.compute_temporal_fluency(seg_files)
+
+ if len(frame_changes) == 0:
+ return None
+
+ # ๊ฒฐ๊ณผ ์ถ๋ ฅ
+ if verbose:
+ self._print_results(
+ overall_fluency, avg_change, variance,
+ frame_changes, window_fluency
+ )
+
+ # ๊ฒฐ๊ณผ ์ ์ฅ
+ folder_name = Path(seg_folder).parent.parent.name
+ # folder_name = Path(seg_folder).parent.name
+
+ results = {
+ 'folder': seg_folder,
+ 'folder_name': folder_name,
+ 'total_frames': len(seg_files),
+ 'fps': self.fps,
+ 'window_seconds': self.window_seconds,
+ 'overall_fluency': overall_fluency,
+ 'avg_change': avg_change,
+ 'variance': variance,
+ 'min_change': float(np.min(frame_changes)),
+ 'max_change': float(np.max(frame_changes)),
+ 'median_change': float(np.median(frame_changes)),
+ 'std_change': float(np.std(frame_changes)),
+ 'frame_changes': frame_changes,
+ 'window_fluency': window_fluency,
+ 'grade': self._get_grade(overall_fluency)
+ }
+
+ if output_dir:
+ self._save_results(results, output_dir, folder_name)
+
+ return results
+
+ def _print_results(
+ self,
+ overall_fluency: float,
+ avg_change: float,
+ variance: float,
+ frame_changes: List[float],
+ window_fluency: List[Dict]
+ ):
+ """๊ฒฐ๊ณผ ์ถ๋ ฅ"""
+ print("\n" + "="*70)
+ print("๐ ํ๊ฐ ๊ฒฐ๊ณผ")
+ print("="*70)
+ print(f"์ ์ฒด ์ ์ฐฝ์ฑ ์ ์ (F_V): {overall_fluency:.4f}")
+ print(f"ํ๊ท ๋ณํ๋ (ฮผ_V): {avg_change:.4f}")
+ print(f"๋ณํ๋ ๋ถ์ฐ (ฯยฒ_V): {variance:.4f}")
+ print(f"๋ณํ๋ ๋ฒ์: [{np.min(frame_changes):.4f}, {np.max(frame_changes):.4f}]")
+ print(f"๋ณํ๋ ์ค์๊ฐ: {np.median(frame_changes):.4f}")
+ print(f"๋ณํ๋ ํ์คํธ์ฐจ: {np.std(frame_changes):.4f}")
+ print(f"๋ฑ๊ธ: {self._get_grade(overall_fluency)}")
+
+ print(f"\n{self.window_seconds}์ด ์๋์ฐ๋ณ ์ ์ฐฝ์ฑ:")
+ for w in window_fluency:
+ print(f" ์๋์ฐ {w['window_id']} "
+ f"(ํ๋ ์ {w['start_frame']}-{w['end_frame']}): "
+ f"F={w['fluency_score']:.4f}, "
+ f"ฮผ={w['avg_change']:.4f}, "
+ f"ฯยฒ={w['variance']:.4f}")
+
+ @staticmethod
+ def _get_grade(fluency: float) -> str:
+ """์ ์ฐฝ์ฑ ์ ์๋ฅผ ๋ฑ๊ธ์ผ๋ก ๋ณํ"""
+ if fluency >= 0.90:
+ return "Excellent"
+ elif fluency >= 0.80:
+ return "Good"
+ elif fluency >= 0.70:
+ return "Fair"
+ elif fluency >= 0.60:
+ return "Poor"
+ else:
+ return "Very Poor"
+
+ def _save_results(self, results: Dict, output_dir: str, folder_name: str):
+ """๊ฒฐ๊ณผ ์ ์ฅ"""
+ os.makedirs(output_dir, exist_ok=True)
+
+ # 1. JSON ์ ์ฅ (์์ธ ์ ๋ณด)
+ json_file = os.path.join(output_dir, f"{folder_name}_fluency.json")
+
+ # frame_changes๋ ์ฉ๋์ด ํฌ๋ฏ๋ก ๋ณ๋ ์ ์ฅ ์ต์
+ json_results = results.copy()
+ if len(results['frame_changes']) > 1000:
+ # ํ๋ ์์ด ๋ง์ผ๋ฉด ํต๊ณ๋ง ์ ์ฅ
+ json_results['frame_changes'] = {
+ 'note': 'Too many frames. Statistics only.',
+ 'count': len(results['frame_changes']),
+ 'sample_first_10': results['frame_changes'][:10],
+ 'sample_last_10': results['frame_changes'][-10:]
+ }
+
+ with open(json_file, 'w') as f:
+ json.dump(json_results, f, indent=2)
+
+ print(f"\nโ
JSON ์ ์ฅ: {json_file}")
+
+ # 2. CSV ์ ์ฅ/์
๋ฐ์ดํธ (์์ฝ ์ ๋ณด)
+ csv_file = os.path.join(output_dir, "fluency_summary.csv")
+ self._update_csv(csv_file, results)
+ print(f"โ
CSV ์
๋ฐ์ดํธ: {csv_file}")
+
+ @staticmethod
+ def _update_csv(csv_file: str, results: Dict):
+ """CSV ํ์ผ ์
๋ฐ์ดํธ (๊ธฐ์กด ํ ๋ฎ์ด์ฐ๊ธฐ ๋๋ ์ ํ ์ถ๊ฐ)"""
+ folder_name = results['folder_name']
+
+ new_row = {
+ 'Folder': folder_name,
+ 'Overall_Fluency': f"{results['overall_fluency']:.4f}",
+ 'Avg_Change': f"{results['avg_change']:.4f}",
+ 'Variance': f"{results['variance']:.4f}",
+ 'Min_Change': f"{results['min_change']:.4f}",
+ 'Max_Change': f"{results['max_change']:.4f}",
+ 'Median_Change': f"{results['median_change']:.4f}",
+ 'Std_Change': f"{results['std_change']:.4f}",
+ 'Frames': results['total_frames'],
+ 'FPS': results['fps'],
+ 'Window_Sec': results['window_seconds'],
+ 'Grade': results['grade']
+ }
+
+ # CSV ํ์ผ ์กด์ฌ ์ฌ๋ถ ํ์ธ
+ if os.path.exists(csv_file):
+ df = pd.read_csv(csv_file)
+
+ # ๊ฐ์ ํด๋๊ฐ ์ด๋ฏธ ์๋์ง ํ์ธ
+ if folder_name in df['Folder'].values:
+ # ๊ธฐ์กด ํ ์
๋ฐ์ดํธ
+ for col, value in new_row.items():
+ if col in df.columns:
+ df.loc[df['Folder'] == folder_name, col] = value
+ else:
+ df[col] = None
+ df.loc[df['Folder'] == folder_name, col] = value
+ print(f" โน๏ธ ๊ธฐ์กด ํ ์
๋ฐ์ดํธ: {folder_name}")
+ else:
+ # ์๋ก์ด ํ ์ถ๊ฐ
+ df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
+ print(f" โ ์๋ก์ด ํ ์ถ๊ฐ: {folder_name}")
+ else:
+ # ์ CSV ํ์ผ ์์ฑ
+ df = pd.DataFrame([new_row])
+ print(f" ๐ ์ CSV ํ์ผ ์์ฑ")
+
+ # CSV ์ ์ฅ
+ df.to_csv(csv_file, index=False)
+
+ def batch_evaluate(
+ self,
+ parent_folder: str,
+ output_dir: Optional[str] = None,
+ pattern: str = '*/images/sapiens_seg'
+ ):
+ """
+ ์ฌ๋ฌ ํด๋ ์ผ๊ด ํ๊ฐ
+
+ Args:
+ parent_folder: ๋ถ๋ชจ ํด๋ ๊ฒฝ๋ก
+ output_dir: ๊ฒฐ๊ณผ ์ ์ฅ ๋๋ ํ ๋ฆฌ
+ pattern: ์ธ๊ทธ๋ฉํ
์ด์
ํด๋ ๊ฒ์ ํจํด
+ """
+ # sapiens_seg ํด๋ ์ฐพ๊ธฐ
+ seg_folders = glob.glob(os.path.join(parent_folder, pattern))
+
+ if len(seg_folders) == 0:
+ print(f"โ {parent_folder}์์ '{pattern}' ํด๋๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค.")
+ return
+
+ print(f"\n{'โ'*70}")
+ print(f"๐ ๋ฐ๊ฒฌ๋ ํด๋: {len(seg_folders)}๊ฐ")
+ print('โ'*70)
+
+ all_results = []
+ failed = []
+
+ for i, seg_folder in enumerate(seg_folders, 1):
+ print(f"\n\n{'='*70}")
+ print(f"[{i}/{len(seg_folders)}]")
+ print('='*70)
+
+ try:
+ result = self.evaluate_folder(seg_folder, output_dir, verbose=True)
+ if result:
+ all_results.append(result)
+ else:
+ failed.append(seg_folder)
+ except Exception as e:
+ print(f"โ ์๋ฌ ๋ฐ์: {e}")
+ failed.append(seg_folder)
+
+ # ์ ์ฒด ์์ฝ
+ self._print_summary(all_results, failed)
+
+ @staticmethod
+ def _print_summary(all_results: List[Dict], failed: List[str]):
+ """์ ์ฒด ์์ฝ ์ถ๋ ฅ"""
+ if not all_results:
+ print("\nโ ๏ธ ํ๊ฐ๋ ๊ฒฐ๊ณผ๊ฐ ์์ต๋๋ค.")
+ return
+
+ print("\n\n" + "="*70)
+ print("๐ ์ ์ฒด ์์ฝ")
+ print("="*70)
+
+ fluency_scores = [r['overall_fluency'] for r in all_results]
+ avg_changes = [r['avg_change'] for r in all_results]
+ variances = [r['variance'] for r in all_results]
+
+ print(f"์ด ํ๊ฐ: {len(all_results)}๊ฐ")
+ print(f"์คํจ: {len(failed)}๊ฐ")
+ print()
+
+ print("์ ์ฐฝ์ฑ ์ ์ (F_V):")
+ print(f" ํ๊ท : {np.mean(fluency_scores):.4f}")
+ print(f" ์ค์๊ฐ: {np.median(fluency_scores):.4f}")
+ print(f" ํ์คํธ์ฐจ: {np.std(fluency_scores):.4f}")
+ print(f" ๋ฒ์: [{np.min(fluency_scores):.4f}, {np.max(fluency_scores):.4f}]")
+ print()
+
+ print("ํ๊ท ๋ณํ๋ (ฮผ_V):")
+ print(f" ํ๊ท : {np.mean(avg_changes):.4f}")
+ print(f" ์ค์๊ฐ: {np.median(avg_changes):.4f}")
+ print()
+
+ print("๋ถ์ฐ (ฯยฒ_V):")
+ print(f" ํ๊ท : {np.mean(variances):.4f}")
+ print(f" ์ค์๊ฐ: {np.median(variances):.4f}")
+ print()
+
+ # ๋ฑ๊ธ๋ณ ๋ถํฌ
+ grades = [r['grade'] for r in all_results]
+ grade_counts = pd.Series(grades).value_counts()
+ print("๋ฑ๊ธ ๋ถํฌ:")
+ for grade in ['Excellent', 'Good', 'Fair', 'Poor', 'Very Poor']:
+ count = grade_counts.get(grade, 0)
+ percentage = (count / len(all_results)) * 100
+ print(f" {grade}: {count}๊ฐ ({percentage:.1f}%)")
+
+ if failed:
+ print(f"\nโ ๏ธ ์คํจํ ํด๋ {len(failed)}๊ฐ:")
+ for folder in failed[:10]:
+ print(f" - {folder}")
+ if len(failed) > 10:
+ print(f" ... ์ธ {len(failed)-10}๊ฐ")
+
+ print("="*70)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Sapiens ์ธ๊ทธ๋ฉํ
์ด์
๊ธฐ๋ฐ Temporal Fluency ํ๊ฐ',
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+์์:
+ # ๋จ์ผ ํด๋ ํ๊ฐ
+ python evaluate_fluency.py /path/to/sapiens_seg --output ./results
+
+ # ์ฌ๋ฌ ํด๋ ์ผ๊ด ํ๊ฐ
+ python evaluate_fluency.py /path/to/parent --batch --output ./results
+
+ # FPS์ ์๋์ฐ ํฌ๊ธฐ ์ง์
+ python evaluate_fluency.py /path/to/sapiens_seg --fps 60 --window 3
+ """
+ )
+
+ parser.add_argument(
+ 'input',
+ help='์ธ๊ทธ๋ฉํ
์ด์
ํด๋ ๋๋ ๋ถ๋ชจ ํด๋ ๊ฒฝ๋ก'
+ )
+ parser.add_argument(
+ '--fps',
+ type=int,
+ default=30,
+ help='ํ๋ ์๋ ์ดํธ (๊ธฐ๋ณธ: 30)'
+ )
+ parser.add_argument(
+ '--window',
+ type=int,
+ default=5,
+ help='ํ๊ฐ ์๋์ฐ ํฌ๊ธฐ (์ด, ๊ธฐ๋ณธ: 5)'
+ )
+ parser.add_argument(
+ '--output',
+ help='๊ฒฐ๊ณผ ์ ์ฅ ๋๋ ํ ๋ฆฌ'
+ )
+ parser.add_argument(
+ '--batch',
+ action='store_true',
+ help='์ฌ๋ฌ ํด๋ ์ผ๊ด ์ฒ๋ฆฌ'
+ )
+ parser.add_argument(
+ '--pattern',
+ default='*/images/sapiens_seg',
+ help='๋ฐฐ์น ๋ชจ๋์์ ํด๋ ๊ฒ์ ํจํด (๊ธฐ๋ณธ: */images/sapiens_seg)'
+ )
+ parser.add_argument(
+ '--quiet',
+ action='store_true',
+ help='์ต์ ์ถ๋ ฅ ๋ชจ๋'
+ )
+
+ args = parser.parse_args()
+
+ # Evaluator ์์ฑ
+ evaluator = FluencyEvaluator(fps=args.fps, window_seconds=args.window)
+
+ # ํ๊ฐ ์คํ
+ if args.batch:
+ evaluator.batch_evaluate(
+ args.input,
+ args.output,
+ pattern=args.pattern
+ )
+ else:
+ evaluator.evaluate_folder(
+ args.input,
+ args.output,
+ verbose=not args.quiet
+ )
+
+
+if __name__ == '__main__':
+ main()
diff --git a/inference_basic.py b/inference_basic.py
index 608ea2b..39f3a5f 100644
--- a/inference_basic.py
+++ b/inference_basic.py
@@ -15,8 +15,9 @@
from animation.modules.pose_net import PoseNet
from animation.modules.unet import UNetSpatioTemporalConditionModel
from animation.pipelines.inference_pipeline_animation import InferenceAnimationPipeline
+from animation.modules.music_encoder import MusicEncoder
import random
-
+import h5py
def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
@@ -36,6 +37,34 @@ def load_images_from_folder(folder, width, height):
return images
+def load_music(filename, length):
+ with h5py.File(filename, "r") as f:
+ m = f["music"][:]
+ # Shape ์ ๋ฆฌ
+ if m.ndim == 3 and m.shape[0] == 1:
+ m = m.squeeze(0) # (1, T, 4800) -> (T, 4800)
+
+ if m.ndim == 2:
+ # (T, 4800) ํํ๊ฐ ๋๋๋ก ํ์ธ
+ if m.shape[1] != 4800 and m.shape[0] == 4800:
+ m = m.T # (4800, T) -> (T, 4800)
+
+ music_fea = torch.from_numpy(m).float()
+
+ T = music_fea.shape[0]
+
+ if T <= length:
+ # ๊ธธ์ด๊ฐ ๋ถ์กฑํ๋ฉด ์์์๋ถํฐ ์๋ผ๋ด๊ฑฐ๋ padding ํ์
+ start = 0
+ else:
+ # 1/4 ~ 3/4 ๊ตฌ๊ฐ์์ ๋๋ค ์์์ ์ ํ
+ quarter = T // 4
+ min_start = quarter
+ max_start = 3 * T // 4
+
+ return music_fea[:, :]
+
+
def save_frames_as_png(frames, output_path):
pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame for frame in frames]
num_frames = len(pil_frames)
@@ -100,11 +129,11 @@ def parse_args():
),
)
parser.add_argument(
- "--validation_control_folder",
+ "--validation_music",
type=str,
default=None,
help=(
- "the validation control image"
+ "the validation control music"
),
)
@@ -118,7 +147,7 @@ def parse_args():
parser.add_argument(
"--height",
type=int,
- default=768,
+ default=512,
required=False
)
@@ -128,6 +157,13 @@ def parse_args():
default=512,
required=False
)
+
+ parser.add_argument(
+ "--length",
+ type=int,
+ default=16,
+ required=False
+ )
parser.add_argument(
"--guidance_scale",
@@ -144,7 +180,7 @@ def parse_args():
)
parser.add_argument(
- "--posenet_model_name_or_path",
+ "--music_encoder_model_name_or_path",
type=str,
default=None,
help="Path to pretrained posenet model",
@@ -182,6 +218,7 @@ def parse_args():
default=0.0, # or set to 0.02
required=False
)
+
parser.add_argument(
"--frames_overlap",
type=int,
@@ -233,7 +270,8 @@ def parse_args():
subfolder="unet",
low_cpu_mem_usage=True,
)
- pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0])
+
+ music_encoder = MusicEncoder()
face_encoder = FusionFaceId(
cross_attention_dim=1024,
id_embeddings_dim=512,
@@ -286,11 +324,11 @@ def parse_args():
unet.set_attn_processor(attn_procs)
# resume the previous checkpoint
- if args.posenet_model_name_or_path is not None and args.face_encoder_model_name_or_path is not None and args.unet_model_name_or_path is not None:
+ if args.music_encoder_model_name_or_path is not None and args.face_encoder_model_name_or_path is not None and args.unet_model_name_or_path is not None:
print("Loading existing posenet weights, face_encoder weights and unet weights.")
- if args.posenet_model_name_or_path.endswith(".pth"):
- pose_net_state_dict = torch.load(args.posenet_model_name_or_path, map_location="cpu")
- pose_net.load_state_dict(pose_net_state_dict, strict=True)
+ if args.music_encoder_model_name_or_path.endswith(".pth"):
+ music_encoder_state_dict = torch.load(args.music_encoder_model_name_or_path, map_location="cpu")
+ music_encoder.load_state_dict(music_encoder_state_dict, strict=True)
else:
print("posenet weights loading fail")
print(1/0)
@@ -311,7 +349,7 @@ def parse_args():
vae.requires_grad_(False)
image_encoder.requires_grad_(False)
unet.requires_grad_(False)
- pose_net.requires_grad_(False)
+ music_encoder.requires_grad_(False)
face_encoder.requires_grad_(False)
if args.gradient_checkpointing:
@@ -327,74 +365,82 @@ def parse_args():
unet=unet,
scheduler=noise_scheduler,
feature_extractor=feature_extractor,
- pose_net=pose_net,
+ music_encoder=music_encoder,
face_encoder=face_encoder,
).to(device='cuda', dtype=weight_dtype)
os.makedirs(args.output_dir, exist_ok=True)
- validation_image_path = args.validation_image
- validation_image = Image.open(args.validation_image).convert('RGB')
- validation_control_images = load_images_from_folder(args.validation_control_folder, width=args.width, height=args.height)
-
- num_frames = len(validation_control_images)
- face_model.face_helper.clean_all()
- validation_face = cv2.imread(validation_image_path)
- validation_image_bgr = cv2.cvtColor(validation_face, cv2.COLOR_RGB2BGR)
- validation_image_face_info = face_model.app.get(validation_image_bgr)
- if len(validation_image_face_info) > 0:
- validation_image_face_info = sorted(validation_image_face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
- validation_image_id_ante_embedding = validation_image_face_info['embedding']
- else:
- validation_image_id_ante_embedding = None
-
- if validation_image_id_ante_embedding is None:
- face_model.face_helper.read_image(validation_image_bgr)
- face_model.face_helper.get_face_landmarks_5(only_center_face=True)
- face_model.face_helper.align_warp_face()
-
- if len(face_model.face_helper.cropped_faces) == 0:
- validation_image_id_ante_embedding = np.zeros((512,))
+ mudic = {'/root/aist_hdf5/rec/gJS_sBM_cAll_d01_mJS0_ch07.h5' : 'Street Jazz',
+ '/root/aist_hdf5/rec/gKR_sBM_cAll_d29_mKR0_ch08.h5' : 'Krump',
+ '/root/aist_hdf5/rec/gLH_sBM_cAll_d16_mLH0_ch09.h5' : 'LA HipHop',
+ '/root/aist_hdf5/rec/gLO_sBM_cAll_d13_mLO0_ch03.h5' : 'Lock',
+ '/root/aist_hdf5/rec/gMH_sBM_cAll_d23_mMH1_ch03.h5' : 'Middle HipHop',
+ '/root/aist_hdf5/rec/gPO_sBM_cAll_d10_mPO0_ch03.h5' : 'Pop'}
+
+ for mu in mudic.keys():
+ validation_image_path = args.validation_image
+ validation_image = Image.open(args.validation_image).convert('RGB')
+ validation_music = load_music(mu, args.length)
+
+ num_frames = validation_music.shape[0]
+ face_model.face_helper.clean_all()
+ validation_face = cv2.imread(validation_image_path)
+ validation_image_bgr = cv2.cvtColor(validation_face, cv2.COLOR_RGB2BGR)
+ validation_image_face_info = face_model.app.get(validation_image_bgr)
+ if len(validation_image_face_info) > 0:
+ validation_image_face_info = sorted(validation_image_face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
+ validation_image_id_ante_embedding = validation_image_face_info['embedding']
else:
- validation_image_align_face = face_model.face_helper.cropped_faces[0]
- print('fail to detect face using insightface, extract embedding on align face')
- validation_image_id_ante_embedding = face_model.handler_ante.get_feat(validation_image_align_face)
-
- # generator = torch.Generator(device=accelerator.device).manual_seed(23123134)
-
- decode_chunk_size = args.decode_chunk_size
- video_frames = pipeline(
- image=validation_image,
- image_pose=validation_control_images,
- height=args.height,
- width=args.width,
- num_frames=num_frames,
- tile_size=args.tile_size,
- tile_overlap=args.frames_overlap,
- decode_chunk_size=decode_chunk_size,
- motion_bucket_id=127.,
- fps=7,
- min_guidance_scale=args.guidance_scale,
- max_guidance_scale=args.guidance_scale,
- noise_aug_strength=args.noise_aug_strength,
- num_inference_steps=args.num_inference_steps,
- generator=generator,
- output_type="pil",
- validation_image_id_ante_embedding=validation_image_id_ante_embedding,
- ).frames[0]
-
- out_file = os.path.join(
- args.output_dir,
- f"animation_video.mp4",
- )
- for i in range(num_frames):
- img = video_frames[i]
- video_frames[i] = np.array(img)
+ validation_image_id_ante_embedding = None
- png_out_file = os.path.join(args.output_dir, "animated_images")
- os.makedirs(png_out_file, exist_ok=True)
- export_to_gif(video_frames, out_file, 8)
- save_frames_as_png(video_frames, png_out_file)
+ if validation_image_id_ante_embedding is None:
+ face_model.face_helper.read_image(validation_image_bgr)
+ face_model.face_helper.get_face_landmarks_5(only_center_face=True)
+ face_model.face_helper.align_warp_face()
+
+ if len(face_model.face_helper.cropped_faces) == 0:
+ validation_image_id_ante_embedding = np.zeros((512,))
+ else:
+ validation_image_align_face = face_model.face_helper.cropped_faces[0]
+ print('fail to detect face using insightface, extract embedding on align face')
+ validation_image_id_ante_embedding = face_model.handler_ante.get_feat(validation_image_align_face)
+
+ # generator = torch.Generator(device=accelerator.device).manual_seed(23123134)
+
+ decode_chunk_size = args.decode_chunk_size
+ video_frames = pipeline(
+ image=validation_image,
+ music=validation_music,
+ height=args.height,
+ width=args.width,
+ num_frames=num_frames,
+ tile_size=args.tile_size,
+ tile_overlap=args.frames_overlap,
+ decode_chunk_size=decode_chunk_size,
+ motion_bucket_id=127.,
+ fps=7,
+ min_guidance_scale=args.guidance_scale,
+ max_guidance_scale=args.guidance_scale,
+ noise_aug_strength=args.noise_aug_strength,
+ num_inference_steps=args.num_inference_steps,
+ generator=generator,
+ output_type="pil",
+ validation_image_id_ante_embedding=validation_image_id_ante_embedding,
+ ).frames[0]
+
+ out_file = os.path.join(
+ args.output_dir,
+ f"{mudic[mu]}.mp4",
+ )
+ for i in range(num_frames):
+ img = video_frames[i]
+ video_frames[i] = np.array(img)
+
+ png_out_file = os.path.join(args.output_dir, f"{mudic[mu]}")
+ os.makedirs(png_out_file, exist_ok=True)
+ export_to_gif(video_frames, out_file, 60)
+ save_frames_as_png(video_frames, png_out_file)
# bash command_basic_infer.sh
diff --git a/toy.py b/toy.py
new file mode 100644
index 0000000..45696d1
--- /dev/null
+++ b/toy.py
@@ -0,0 +1,28 @@
+import os
+from PIL import Image
+
+# ์
๋ ฅ ํด๋์ ์ถ๋ ฅ ํด๋
+input_dir = "/root/dataset/rec/gKR_sBM_cAll_d28_mKR1_ch07/images"
+output_dir = os.path.join(input_dir, "square_cropped")
+os.makedirs(output_dir, exist_ok=True)
+
+for fname in os.listdir(input_dir):
+ if fname.lower().endswith(".png"):
+ path = os.path.join(input_dir, fname)
+ img = Image.open(path)
+
+ w, h = img.size
+ # ์ ์ฌ๊ฐ ํฌ๊ธฐ = ์ธ๋ก๊ธธ์ด
+ square_size = h
+
+ # ๊ฐ๋ก ๊ธฐ์ค ์ค์ ์ ๋ ฌ
+ left = (w - square_size) // 2
+ top = 0
+ right = left + square_size
+ bottom = h
+
+ cropped = img.crop((left, top, right, bottom))
+ save_path = os.path.join(output_dir, fname)
+ cropped.save(save_path)
+
+print("โ
์ธ๋ก ๊ธธ์ด์ ๋ง์ถฐ ์ค์ ์ ์ฌ๊ฐ ํฌ๋กญ ์๋ฃ. ์ ์ฅ ์์น:", output_dir)
diff --git a/train.py b/train.py
deleted file mode 100644
index b56696d..0000000
--- a/train.py
+++ /dev/null
@@ -1,1695 +0,0 @@
-import argparse
-import random
-import logging
-import math
-import os
-
-import cv2
-import shutil
-from pathlib import Path
-from urllib.parse import urlparse
-import numpy as np
-import PIL
-from PIL import Image, ImageDraw
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from diffusers.models.attention_processor import XFormersAttnProcessor
-
-from animation.dataset.animation_dataset import LargeScaleAnimationVideos
-from animation.modules.attention_processor import AnimationAttnProcessor
-from animation.modules.attention_processor_normalized import AnimationIDAttnNormalizedProcessor
-from animation.modules.face_model import FaceModel
-from animation.modules.id_encoder import FusionFaceId
-from animation.modules.pose_net import PoseNet
-from animation.modules.unet import UNetSpatioTemporalConditionModel
-
-from animation.pipelines.validation_pipeline_animation import ValidationAnimationPipeline
-import transformers
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
-from packaging import version
-from tqdm.auto import tqdm
-from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-from einops import rearrange
-
-import datetime
-import diffusers
-from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
-from diffusers.utils.import_utils import is_xformers_available
-import warnings
-import torch.nn as nn
-from diffusers.utils.torch_utils import randn_tensor
-
-# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
-
-logger = get_logger(__name__, log_level="INFO")
-
-#i should make a utility function file
-def validate_and_convert_image(image, target_size=(256, 256)):
- if image is None:
- print("Encountered a None image")
- return None
-
- if isinstance(image, torch.Tensor):
- # Convert PyTorch tensor to PIL Image
- if image.ndim == 3 and image.shape[0] in [1, 3]: # Check for CxHxW format
- if image.shape[0] == 1: # Convert single-channel grayscale to RGB
- image = image.repeat(3, 1, 1)
- image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
- image = Image.fromarray(image)
- else:
- print(f"Invalid image tensor shape: {image.shape}")
- return None
- elif isinstance(image, Image.Image):
- # Resize PIL Image
- image = image.resize(target_size)
- else:
- print("Image is not a PIL Image or a PyTorch tensor")
- return None
-
- return image
-
-def create_image_grid(images, rows, cols, target_size=(256, 256)):
- valid_images = [validate_and_convert_image(img, target_size) for img in images]
- valid_images = [img for img in valid_images if img is not None]
-
- if not valid_images:
- print("No valid images to create a grid")
- return None
-
- w, h = target_size
- grid = Image.new('RGB', size=(cols * w, rows * h))
-
- for i, image in enumerate(valid_images):
- grid.paste(image, box=((i % cols) * w, (i // cols) * h))
-
- return grid
-
-def save_combined_frames(batch_output, validation_images, validation_control_images,output_folder):
- # Flatten batch_output, which is a list of lists of PIL Images
- flattened_batch_output = [img for sublist in batch_output for img in sublist]
-
- # Combine frames into a list without converting (since they are already PIL Images)
- combined_frames = validation_images + validation_control_images + flattened_batch_output
-
- # Calculate rows and columns for the grid
- num_images = len(combined_frames)
- cols = 3 # adjust number of columns as needed
- rows = (num_images + cols - 1) // cols
- timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-
- filename = f"combined_frames_{timestamp}.png"
- # Create and save the grid image
- grid = create_image_grid(combined_frames, rows, cols)
- output_folder = os.path.join(output_folder, "validation_images")
- os.makedirs(output_folder, exist_ok=True)
-
- # Now define the full path for the file
- timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
- filename = f"combined_frames_{timestamp}.png"
- output_loc = os.path.join(output_folder, filename)
-
- if grid is not None:
- grid.save(output_loc)
- else:
- print("Failed to create image grid")
-
-
-
-# def load_images_from_folder(folder):
-# images = []
-# valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed
-#
-# # Function to extract frame number from the filename
-# def frame_number(filename):
-# # First, try the pattern 'frame_x_7fps'
-# new_pattern_match = re.search(r'frame_(\d+)_7fps', filename)
-# if new_pattern_match:
-# return int(new_pattern_match.group(1))
-# # If the new pattern is not found, use the original digit extraction method
-# matches = re.findall(r'\d+', filename)
-# if matches:
-# if matches[-1] == '0000' and len(matches) > 1:
-# return int(matches[-2]) # Return the second-to-last sequence if the last is '0000'
-# return int(matches[-1]) # Otherwise, return the last sequence
-# return float('inf') # Return 'inf'
-#
-# # Sorting files based on frame number
-# sorted_files = sorted(os.listdir(folder), key=frame_number)
-#
-# # Load images in sorted order
-# for filename in sorted_files:
-# ext = os.path.splitext(filename)[1].lower()
-# if ext in valid_extensions:
-# img = Image.open(os.path.join(folder, filename)).convert('RGB')
-# images.append(img)
-#
-# return images
-
-def load_images_from_folder(folder):
- images = []
-
- files = os.listdir(folder)
- png_files = [f for f in files if f.endswith('.png')]
- png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
- for filename in png_files:
- img = Image.open(os.path.join(folder, filename)).convert('RGB')
- images.append(img)
-
- return images
-
-
-
-# copy from https://github.com/crowsonkb/k-diffusion.git
-def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
- """Draws stratified samples from a uniform distribution."""
- if groups <= 0:
- raise ValueError(f"groups must be positive, got {groups}")
- if group < 0 or group >= groups:
- raise ValueError(f"group must be in [0, {groups})")
- n = shape[-1] * groups
- offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
- u = torch.rand(shape, dtype=dtype, device=device)
- return (offsets + u) / n
-
-
-def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3, device='cpu', dtype=torch.float32):
- """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""
-
- def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
- t_min = math.atan(math.exp(-0.5 * logsnr_max))
- t_max = math.atan(math.exp(-0.5 * logsnr_min))
- return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
-
- def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
- shift = 2 * math.log(noise_d / image_d)
- return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift
-
- def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
- logsnr_low = logsnr_schedule_cosine_shifted(
- t, image_d, noise_d_low, logsnr_min, logsnr_max)
- logsnr_high = logsnr_schedule_cosine_shifted(
- t, image_d, noise_d_high, logsnr_min, logsnr_max)
- return torch.lerp(logsnr_low, logsnr_high, t)
-
- logsnr_min = -2 * math.log(min_value / sigma_data)
- logsnr_max = -2 * math.log(max_value / sigma_data)
- u = stratified_uniform(
- shape, group=0, groups=1, dtype=dtype, device=device
- )
- logsnr = logsnr_schedule_cosine_interpolated(
- u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
- return torch.exp(-logsnr / 2) * sigma_data
-
-def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
- """Draws samples from an lognormal distribution."""
- u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
- return torch.distributions.Normal(loc, scale).icdf(u).exp()
-
-min_value = 0.002
-max_value = 700
-image_d = 64
-noise_d_low = 32
-noise_d_high = 64
-sigma_data = 0.5
-
-
-def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
- h, w = input.shape[-2:]
- factors = (h / size[0], w / size[1])
-
- # First, we have to determine sigma
- # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
- sigmas = (
- max((factors[0] - 1.0) / 2.0, 0.001),
- max((factors[1] - 1.0) / 2.0, 0.001),
- )
-
- # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
- # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
- # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
- ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
-
- # Make sure it is odd
- if (ks[0] % 2) == 0:
- ks = ks[0] + 1, ks[1]
-
- if (ks[1] % 2) == 0:
- ks = ks[0], ks[1] + 1
-
- input = _gaussian_blur2d(input, ks, sigmas)
-
- output = torch.nn.functional.interpolate(
- input, size=size, mode=interpolation, align_corners=align_corners)
- return output
-
-
-def _compute_padding(kernel_size):
- """Compute padding tuple."""
- # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
- # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
- if len(kernel_size) < 2:
- raise AssertionError(kernel_size)
- computed = [k - 1 for k in kernel_size]
-
- # for even kernels we need to do asymmetric padding :(
- out_padding = 2 * len(kernel_size) * [0]
-
- for i in range(len(kernel_size)):
- computed_tmp = computed[-(i + 1)]
-
- pad_front = computed_tmp // 2
- pad_rear = computed_tmp - pad_front
-
- out_padding[2 * i + 0] = pad_front
- out_padding[2 * i + 1] = pad_rear
-
- return out_padding
-
-
-def _filter2d(input, kernel):
- # prepare kernel
- b, c, h, w = input.shape
- tmp_kernel = kernel[:, None, ...].to(
- device=input.device, dtype=input.dtype)
-
- tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
-
- height, width = tmp_kernel.shape[-2:]
-
- padding_shape: list[int] = _compute_padding([height, width])
- input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
-
- # kernel and input tensor reshape to align element-wise or batch-wise params
- tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
- input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
-
- # convolve the tensor with the kernel.
- output = torch.nn.functional.conv2d(
- input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
-
- out = output.view(b, c, h, w)
- return out
-
-
-def _gaussian(window_size: int, sigma):
- if isinstance(sigma, float):
- sigma = torch.tensor([[sigma]])
-
- batch_size = sigma.shape[0]
-
- x = (torch.arange(window_size, device=sigma.device,
- dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
-
- if window_size % 2 == 0:
- x = x + 0.5
-
- gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
-
- return gauss / gauss.sum(-1, keepdim=True)
-
-
-def _gaussian_blur2d(input, kernel_size, sigma):
- if isinstance(sigma, tuple):
- sigma = torch.tensor([sigma], dtype=input.dtype)
- else:
- sigma = sigma.to(dtype=input.dtype)
-
- ky, kx = int(kernel_size[0]), int(kernel_size[1])
- bs = sigma.shape[0]
- kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
- kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
- out_x = _filter2d(input, kernel_x[..., None, :])
- out = _filter2d(out_x, kernel_y[..., None])
-
- return out
-
-
-def export_to_video(video_frames, output_video_path, fps):
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
- h, w, _ = video_frames[0].shape
- video_writer = cv2.VideoWriter(
- output_video_path, fourcc, fps=fps, frameSize=(w, h))
- for i in range(len(video_frames)):
- img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
- video_writer.write(img)
-
-
-def export_to_gif(frames, output_gif_path, fps):
- """
- Export a list of frames to a GIF.
-
- Args:
- - frames (list): List of frames (as numpy arrays or PIL Image objects).
- - output_gif_path (str): Path to save the output GIF.
- - duration_ms (int): Duration of each frame in milliseconds.
-
- """
- # Convert numpy arrays to PIL Images if needed
- pil_frames = [Image.fromarray(frame) if isinstance(
- frame, np.ndarray) else frame for frame in frames]
-
- pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'),
- format='GIF',
- append_images=pil_frames[1:],
- save_all=True,
- duration=125,
- loop=0)
-
-
-def tensor_to_vae_latent(t, vae, scale=True):
- t = t.to(vae.dtype)
- if len(t.shape) == 5:
- video_length = t.shape[1]
-
- t = rearrange(t, "b f c h w -> (b f) c h w")
- latents = vae.encode(t).latent_dist.sample()
- latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
- elif len(t.shape) == 4:
- latents = vae.encode(t).latent_dist.sample()
- if scale:
- latents = latents * vae.config.scaling_factor
- return latents
-
-
-def parse_args():
- parser = argparse.ArgumentParser(
- description="Script to train Stable Diffusion XL for InstructPix2Pix."
- )
- parser.add_argument(
- "--pretrained_model_name_or_path",
- type=str,
- default=None,
- required=True,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--revision",
- type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
-
- parser.add_argument(
- "--num_frames",
- type=int,
- default=14,
- )
- parser.add_argument(
- "--dataset_type",
- type=str,
- default='ubc',
- )
- parser.add_argument(
- "--num_validation_images",
- type=int,
- default=1,
- help="Number of images that should be generated during validation with `validation_prompt`.",
- )
- parser.add_argument(
- "--validation_steps",
- type=int,
- default=500,
- help=(
- "Run fine-tuning validation every X epochs. The validation process consists of running the text/image prompt"
- " multiple times: `args.num_validation_images`."
- ),
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="./outputs",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- parser.add_argument(
- "--seed", type=int, default=None, help="A seed for reproducible training."
- )
- parser.add_argument(
- "--per_gpu_batch_size",
- type=int,
- default=1,
- help="Batch size (per device) for the training dataloader.",
- )
- parser.add_argument("--num_train_epochs", type=int, default=100)
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=None,
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
- )
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
- parser.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=1e-4,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument(
- "--scale_lr",
- action="store_true",
- default=False,
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- parser.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- parser.add_argument(
- "--lr_warmup_steps",
- type=int,
- default=500,
- help="Number of steps for the warmup in the lr scheduler.",
- )
- parser.add_argument(
- "--conditioning_dropout_prob",
- type=float,
- default=0.1,
- help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
- )
- parser.add_argument(
- "--use_8bit_adam",
- action="store_true",
- help="Whether or not to use 8-bit Adam from bitsandbytes.",
- )
- parser.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- parser.add_argument(
- "--use_ema", action="store_true", help="Whether to use EMA model."
- )
- parser.add_argument(
- "--non_ema_revision",
- type=str,
- default=None,
- required=False,
- help=(
- "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
- " remote repository specified with --pretrained_model_name_or_path."
- ),
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=8,
- help=(
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
- ),
- )
- parser.add_argument(
- "--adam_beta1",
- type=float,
- default=0.9,
- help="The beta1 parameter for the Adam optimizer.",
- )
- parser.add_argument(
- "--adam_beta2",
- type=float,
- default=0.999,
- help="The beta2 parameter for the Adam optimizer.",
- )
- parser.add_argument(
- "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
- )
- parser.add_argument(
- "--adam_epsilon",
- type=float,
- default=1e-08,
- help="Epsilon value for the Adam optimizer",
- )
- parser.add_argument(
- "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
- )
- parser.add_argument(
- "--push_to_hub",
- action="store_true",
- help="Whether or not to push the model to the Hub.",
- )
- parser.add_argument(
- "--hub_token",
- type=str,
- default=None,
- help="The token to use to push to the Model Hub.",
- )
- parser.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- parser.add_argument(
- "--logging_dir",
- type=str,
- default="logs",
- help=(
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
- ),
- )
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default=None,
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
- )
- parser.add_argument(
- "--report_to",
- type=str,
- default="tensorboard",
- help=(
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
- ),
- )
- parser.add_argument(
- "--local_rank",
- type=int,
- default=-1,
- help="For distributed training: local_rank",
- )
- parser.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
- " training using `--resume_from_checkpoint`."
- ),
- )
- parser.add_argument(
- "--checkpoints_total_limit",
- type=int,
- default=1,
- help=("Max number of checkpoints to store."),
- )
- parser.add_argument(
- "--resume_from_checkpoint",
- type=str,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- parser.add_argument(
- "--enable_xformers_memory_efficient_attention",
- action="store_true",
- help="Whether or not to use xformers.",
- )
- parser.add_argument(
- "--log_trainable_parameters",
- action="store_true",
- help="Whether to write the trainable parameters.",
- )
- parser.add_argument(
- "--pretrain_unet",
- type=str,
- default=None,
- help="use weight for unet block",
- )
- parser.add_argument(
- "--rank",
- type=int,
- default=128,
- help=("The dimension of the LoRA update matrices."),
- )
- parser.add_argument(
- "--csv_path",
- type=str,
- default=None,
- help=(
- "path to the dataset csv"
- ),
- )
- parser.add_argument(
- "--video_folder",
- type=str,
- default=None,
- help=(
- "path to the video folder"
- ),
- )
- parser.add_argument(
- "--condition_folder",
- type=str,
- default=None,
- help=(
- "path to the depth folder"
- ),
- )
- parser.add_argument(
- "--motion_folder",
- type=str,
- default=None,
- help=(
- "path to the depth folder"
- ),
- )
- parser.add_argument(
- "--validation_prompt",
- type=str,
- default=None,
- help=(
- "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
- " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
- " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
- ),
- )
- parser.add_argument(
- "--validation_image_folder",
- type=str,
- default=None,
- help=(
- "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`"
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
- " `--validation_image` that will be used with all `--validation_prompt`s."
- ),
- )
- parser.add_argument(
- "--validation_image",
- type=str,
- default=None,
- help=(
- "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`"
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
- " `--validation_image` that will be used with all `--validation_prompt`s."
- ),
- )
- parser.add_argument(
- "--validation_control_folder",
- type=str,
- default=None,
- help=(
- "the validation control image"
- ),
- )
- parser.add_argument(
- "--sample_n_frames",
- type=int,
- default=14,
- help=(
- "the sample_n_frames"
- ),
- )
-
- parser.add_argument(
- "--ref_augment",
- action="store_true",
- help=(
- "use augment for the reference image"
- ),
- )
- parser.add_argument(
- "--train_stage",
- type=int,
- default=2,
- help=(
- "the training stage"
- ),
- )
-
- parser.add_argument(
- "--posenet_model_name_or_path",
- type=str,
- default=None,
- help="Path to pretrained posenet model",
- )
- parser.add_argument(
- "--face_encoder_model_name_or_path",
- type=str,
- default=None,
- help="Path to pretrained face encoder model",
- )
- parser.add_argument(
- "--unet_model_name_or_path",
- type=str,
- default=None,
- help="Path to pretrained unet model",
- )
-
- parser.add_argument(
- "--data_root_path",
- type=str,
- default=None,
- help="Path to the data root path",
- )
- parser.add_argument(
- "--rec_data_path",
- type=str,
- default=None,
- help="Path to the rec data path",
- )
- parser.add_argument(
- "--vec_data_path",
- type=str,
- default=None,
- help="Path to the vec data path",
- )
-
- parser.add_argument(
- "--finetune_mode",
- type=bool,
- default=False,
- help="Enable or disable the finetune mode (True/False).",
- )
- parser.add_argument(
- "--posenet_model_finetune_path",
- type=str,
- default=None,
- help="Path to the pretrained posenet model",
- )
- parser.add_argument(
- "--face_encoder_finetune_path",
- type=str,
- default=None,
- help="Path to the pretrained face encoder",
- )
- parser.add_argument(
- "--unet_model_finetune_path",
- type=str,
- default=None,
- help="Path to the pretrained unet model",
- )
-
- args = parser.parse_args()
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
- if env_local_rank != -1 and env_local_rank != args.local_rank:
- args.local_rank = env_local_rank
-
- # default to using the same revision for the non-ema model if not specified
- if args.non_ema_revision is None:
- args.non_ema_revision = args.revision
-
- return args
-
-
-def download_image(url):
- original_image = (
- lambda image_url_or_path: load_image(image_url_or_path)
- if urlparse(image_url_or_path).scheme
- else PIL.Image.open(image_url_or_path).convert("RGB")
- )(url)
- return original_image
-
-
-# This is for training using deepspeed.
-# Since now the DeepSpeed only supports trainging with only one model
-# So we create a virtual wrapper to contail all the models
-
-class DeepSpeedWrapperModel(nn.Module):
- def __init__(self, **kwargs):
- super().__init__()
- for name, value in kwargs.items():
- assert isinstance(value, nn.Module)
- self.register_module(name, value)
-
-
-def main():
-
- warnings.filterwarnings('ignore', category=DeprecationWarning)
- warnings.filterwarnings('ignore', category=FutureWarning)
- torch.multiprocessing.set_start_method('spawn')
-
- args = parse_args()
-
- if args.non_ema_revision is not None:
- deprecate(
- "non_ema_revision!=None",
- "0.15.0",
- message=(
- "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
- " use `--variant=non_ema` instead."
- ),
- )
- logging_dir = os.path.join(args.output_dir, args.logging_dir)
- accelerator_project_config = ProjectConfiguration(
- project_dir=args.output_dir, logging_dir=logging_dir)
- # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
- accelerator = Accelerator(
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- mixed_precision=args.mixed_precision,
- project_config=accelerator_project_config,
- )
-
- generator = torch.Generator(
- device=accelerator.device).manual_seed(23123134)
-
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError(
- "Make sure to install wandb if you want to use it for logging during training.")
- import wandb
-
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO,
- )
- logger.info(accelerator.state, main_process_only=False)
- if accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
-
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
-
- # Load scheduler, tokenizer and models.
- print(args.pretrained_model_name_or_path)
- feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path, subfolder="feature_extractor", revision=args.revision)
- noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision
- )
- vae = AutoencoderKLTemporalDecoder.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant="fp16")
- unet = UNetSpatioTemporalConditionModel.from_pretrained(
- args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet,
- subfolder="unet",
- low_cpu_mem_usage=True,
- variant="fp16"
- )
- pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0])
- face_encoder = FusionFaceId(
- cross_attention_dim=1024,
- id_embeddings_dim=512,
- clip_embeddings_dim=1024,
- num_tokens=4,)
- face_model = FaceModel()
-
- # init adapter modules
- lora_rank = 128
- attn_procs = {}
- unet_svd = unet.state_dict()
-
- for name in unet.attn_processors.keys():
- if "transformer_blocks" in name and "temporal_transformer_blocks" not in name:
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
- if cross_attention_dim is None:
- # print(f"This is AnimationAttnProcessor: {name}")
- attn_procs[name] = AnimationAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
- else:
- # print(f"This is AnimationIDAttnNormalizedProcessor: {name}")
- layer_name = name.split(".processor")[0]
- weights = {
- "id_to_k.weight": unet_svd[layer_name + ".to_k.weight"],
- "id_to_v.weight": unet_svd[layer_name + ".to_v.weight"],
- }
- attn_procs[name] = AnimationIDAttnNormalizedProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank)
- attn_procs[name].load_state_dict(weights, strict=False)
- elif "temporal_transformer_blocks" in name:
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
- if cross_attention_dim is None:
- attn_procs[name] = XFormersAttnProcessor()
- else:
- attn_procs[name] = XFormersAttnProcessor()
- unet.set_attn_processor(attn_procs)
-
- # triggering the finetune mode
- if args.finetune_mode is True and args.posenet_model_finetune_path is not None and args.face_encoder_finetune_path is not None and args.unet_model_finetune_path is not None:
- print("Loading existing posenet weights, face_encoder weights and unet weights.")
- if args.posenet_model_finetune_path.endswith(".pth"):
- pose_net_state_dict = torch.load(args.posenet_model_finetune_path, map_location="cpu")
- pose_net.load_state_dict(pose_net_state_dict, strict=True)
- else:
- print("posenet weights loading fail")
- print(1/0)
- if args.face_encoder_finetune_path.endswith(".pth"):
- face_encoder_state_dict = torch.load(args.face_encoder_finetune_path, map_location="cpu")
- face_encoder.load_state_dict(face_encoder_state_dict, strict=True)
- else:
- print("face_encoder weights loading fail")
- print(1/0)
- if args.unet_model_finetune_path.endswith(".pth"):
- unet_state_dict = torch.load(args.unet_model_finetune_path, map_location="cpu")
- unet.load_state_dict(unet_state_dict, strict=True)
- else:
- print("unet weights loading fail")
- print(1/0)
-
-
- vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
- image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
-
- # Freeze vae and image_encoder
- vae.requires_grad_(False)
- image_encoder.requires_grad_(False)
- unet.requires_grad_(False)
- pose_net.requires_grad_(False)
- face_encoder.requires_grad_(False)
-
- weight_dtype = torch.float32
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- image_encoder.to(accelerator.device, dtype=weight_dtype)
- vae.to(accelerator.device, dtype=weight_dtype)
-
- if args.use_ema:
- ema_unet = EMAModel(unet.parameters(
- ), model_cls=UNetSpatioTemporalConditionModel, model_config=unet.config)
-
- if args.enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- import xformers
- xformers_version = version.parse(xformers.__version__)
- if xformers_version == version.parse("0.0.16"):
- logger.warn(
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
- )
- unet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError(
- "xformers is not available. Make sure it is installed correctly")
-
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
-
-
- # Enable TF32 for faster training on Ampere GPUs,
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if args.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
-
- if args.scale_lr:
- args.learning_rate = (
- args.learning_rate * args.gradient_accumulation_steps *
- args.per_gpu_batch_size * accelerator.num_processes
- )
-
- # Initialize the optimizer
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
- )
-
- optimizer_cls = bnb.optim.AdamW8bit
- else:
- optimizer_cls = torch.optim.AdamW
-
- # if accelerator.distributed_type == DistributedType.DEEPSPEED:
- # ds_wrapper = DeepSpeedWrapperModel(
- # unet=unet,
- # controlnext=controlnext
- # )
- # unet = ds_wrapper.unet
- # controlnext = ds_wrapper.controlnext
-
-
- pose_net.requires_grad_(True)
- face_encoder.requires_grad_(True)
-
- parameters_list = []
-
- for name, para in pose_net.named_parameters():
- para.requires_grad = True
- parameters_list.append({"params": para, "lr": args.learning_rate } )
-
- for name, para in face_encoder.named_parameters():
- para.requires_grad = True
- parameters_list.append({"params": para, "lr": args.learning_rate } )
-
-
- """
- For more details, please refer to: https://github.com/dvlab-research/ControlNeXt/issues/14#issuecomment-2290450333
- This is the selective parameters part.
- As presented in our paper, we only select a small subset of parameters, which is fully adapted to the SD1.5 and SDXL backbones. By training fewer than 100 million parameters, we still achieve excellent performance. But this is is not suitable for the SD3 and SVD training. This is because, after SDXL, Stability faced significant legal risks due to the generation of highly realistic human images. After that, they stopped refining their models on human-related data, such as SVD and SD3, to avoid potential risks.
- To achieve optimal performance, it's necessary to first continue training SVD and SD3 on human-related data to develop a robust backbone before fine-tuning. Of course, you can also combine the continual pretraining and finetuning. So you can find that we direct provide the full SVD parameters.
- We have experimented with two approaches: 1.Directly training the model from scratch on human dancing data. 2. Continual training using a pre-trained human generation backbone, followed by fine-tuning a selective small subset of parameters. Interestingly, we observed no significant difference in performance between these two methods.
- """
-
- for name, para in unet.named_parameters():
- if "attentions" in name:
- para.requires_grad = True
- parameters_list.append({"params": para})
- else:
- para.requires_grad = False
-
- optimizer = optimizer_cls(
- parameters_list,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- )
-
- # check para
- if accelerator.is_main_process and args.log_trainable_parameters:
- rec_txt1 = open('rec_para.txt', 'w')
- rec_txt2 = open('rec_para_train.txt', 'w')
- for name, para in unet.named_parameters():
- if para.requires_grad is False:
- rec_txt1.write(f'{name}\n')
- else:
- rec_txt2.write(f'{name}\n')
- rec_txt1.close()
- rec_txt2.close()
- # DataLoaders creation:
- args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
-
- root_path = args.data_root_path
- txt_path_1 = args.rec_data_path
- txt_path_2 = args.vec_data_path
- train_dataset_1 = LargeScaleAnimationVideos(
- root_path=root_path,
- txt_path=txt_path_1,
- width=512,
- height=512,
- n_sample_frames=args.sample_n_frames,
- sample_frame_rate=4,
- app=face_model.app,
- handler_ante=face_model.handler_ante,
- face_helper=face_model.face_helper
- )
- train_dataloader_1 = torch.utils.data.DataLoader(
- train_dataset_1,
- batch_size=args.per_gpu_batch_size,
- num_workers=args.num_workers,
- shuffle=True,
- )
- train_dataset_2 = LargeScaleAnimationVideos(
- root_path=root_path,
- txt_path=txt_path_2,
- width=576,
- height=1024,
- n_sample_frames=args.sample_n_frames,
- sample_frame_rate=4,
- app=face_model.app,
- handler_ante=face_model.handler_ante,
- face_helper=face_model.face_helper
- )
- train_dataloader_2 = torch.utils.data.DataLoader(
- train_dataset_2,
- batch_size=args.per_gpu_batch_size,
- num_workers=args.num_workers,
- shuffle=True
- )
-
- # Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil((len(train_dataloader_1) + len(train_dataloader_2)) / args.gradient_accumulation_steps)
- if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
-
- lr_scheduler = get_scheduler(
- args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
- )
-
- unet, pose_net, face_encoder, optimizer, lr_scheduler, train_dataloader_1, train_dataloader_2 = accelerator.prepare(
- unet, pose_net, face_encoder, optimizer, lr_scheduler, train_dataloader_1, train_dataloader_2
- )
-
- if args.use_ema:
- ema_unet.to(accelerator.device)
-
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil((len(train_dataloader_1) + len(train_dataloader_2)) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- # Afterwards we recalculate our number of training epochs
- args.num_train_epochs = math.ceil(
- args.max_train_steps / num_update_steps_per_epoch)
-
- # We need to initialize the trackers we use, and also store our configuration.
- # The trackers initializes automatically on the main process.
- if accelerator.is_main_process:
- accelerator.init_trackers("StableAnimator", config=vars(args))
-
- # Train!
- total_batch_size = args.per_gpu_batch_size * \
- accelerator.num_processes * args.gradient_accumulation_steps
-
- len_zeros = len(train_dataloader_1)
- len_ones = len(train_dataloader_2)
-
- logger.info("***** Running training *****")
- logger.info(f" Num examples = {len(train_dataset_1)+len(train_dataset_2)}")
- logger.info(f" Num Epochs = {args.num_train_epochs}")
- logger.info(
- f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
- logger.info(
- f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(
- f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {args.max_train_steps}")
- global_step = 0
- first_epoch = 0
-
- def encode_image(pixel_values):
- pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
- pixel_values = (pixel_values + 1.0) / 2.0
-
- pixel_values = pixel_values.to(torch.float32)
- # Normalize the image with for CLIP input
- pixel_values = feature_extractor(
- images=pixel_values,
- do_normalize=True,
- do_center_crop=False,
- do_resize=False,
- do_rescale=False,
- return_tensors="pt",
- ).pixel_values
-
- pixel_values = pixel_values.to(
- device=accelerator.device, dtype=image_encoder.dtype)
- image_embeddings = image_encoder(pixel_values).image_embeds
- image_embeddings= image_embeddings.unsqueeze(1)
- return image_embeddings
-
-
- def _get_add_time_ids(
- fps,
- motion_bucket_id,
- noise_aug_strength,
- dtype,
- batch_size,
- unet=None,
- device=None
- ):
- add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
-
-
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=device)
- add_time_ids = add_time_ids.repeat(batch_size, 1)
- return add_time_ids
-
- # Potentially load in the weights and states from a previous save
- if args.resume_from_checkpoint:
- if args.resume_from_checkpoint != "latest":
- path = os.path.basename(args.resume_from_checkpoint)
- else:
- # Get the most recent checkpoint
- dirs = os.listdir(args.output_dir)
- dirs = [d for d in dirs if d.startswith("checkpoint")]
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
- path = dirs[-1] if len(dirs) > 0 else None
-
- if path is None:
- accelerator.print(
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
- )
- args.resume_from_checkpoint = None
- else:
- accelerator.print(f"Resuming from checkpoint {path}")
- accelerator.load_state(os.path.join(args.output_dir, path))
- global_step = int(path.split("-")[1])
-
- resume_global_step = global_step * args.gradient_accumulation_steps
- first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (
- num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps),
- disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
-
- for epoch in range(first_epoch, args.num_train_epochs):
- pose_net.train()
- face_encoder.train()
- unet.train()
- train_loss = 0.0
-
- iter1 = iter(train_dataloader_1)
- iter2 = iter(train_dataloader_2)
- list_0_and_1 = [0] * len_zeros + [1] * len_ones
- random.shuffle(list_0_and_1)
- for step in range(0, len(list_0_and_1)):
- current_idx = list_0_and_1[step]
- if current_idx == 0:
- try:
- batch = next(iter1)
- except StopIteration:
- iter1 = iter(train_dataloader_1)
- batch = next(iter1)
- elif current_idx == 1:
- try:
- batch = next(iter2)
- except StopIteration:
- iter2 = iter(train_dataloader_2)
- batch = next(iter2)
-
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
- with accelerator.accumulate(pose_net, face_encoder, unet):
- with accelerator.autocast():
- pixel_values = batch["pixel_values"].to(weight_dtype).to(
- accelerator.device, non_blocking=True
- )
- conditional_pixel_values = batch["reference_image"].to(weight_dtype).to(
- accelerator.device, non_blocking=True
- )
-
- latents = tensor_to_vae_latent(pixel_values, vae).to(dtype=weight_dtype)
-
- # Get the text embedding for conditioning.
- encoder_hidden_states = encode_image(conditional_pixel_values).to(dtype=weight_dtype)
- image_embed = encoder_hidden_states.clone()
-
- train_noise_aug = 0.02
- conditional_pixel_values = conditional_pixel_values + train_noise_aug * \
- randn_tensor(conditional_pixel_values.shape, generator=generator, device=conditional_pixel_values.device, dtype=conditional_pixel_values.dtype)
- conditional_latents = tensor_to_vae_latent(conditional_pixel_values, vae, scale=False)
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents)
- bsz = latents.shape[0]
- # Sample a random timestep for each image
- sigmas = rand_cosine_interpolated(shape=[bsz,], image_d=image_d, noise_d_low=noise_d_low, noise_d_high=noise_d_high, sigma_data=sigma_data, min_value=min_value, max_value=max_value).to(latents.device, dtype=weight_dtype)
-
- # sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents)
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- sigmas_reshaped = sigmas.clone()
- while len(sigmas_reshaped.shape) < len(latents.shape):
- sigmas_reshaped = sigmas_reshaped.unsqueeze(-1)
-
-
- noisy_latents = latents + noise * sigmas_reshaped
-
- timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(latents.device, dtype=weight_dtype)
-
-
- inp_noisy_latents = noisy_latents / ((sigmas_reshaped**2 + 1) ** 0.5)
-
- added_time_ids = _get_add_time_ids(
- fps=6,
- motion_bucket_id=127.0,
- noise_aug_strength=train_noise_aug, # noise_aug_strength == 0.0
- dtype=encoder_hidden_states.dtype,
- batch_size=bsz,
- unet=unet,
- device=latents.device
- )
-
- added_time_ids = added_time_ids.to(latents.device)
-
- # Conditioning dropout to support classifier-free guidance during inference. For more details
- # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
- if args.conditioning_dropout_prob is not None:
- random_p = torch.rand(
- bsz, device=latents.device, generator=generator)
- # Sample masks for the edit prompts.
- prompt_mask = random_p < 2 * args.conditioning_dropout_prob
- prompt_mask = prompt_mask.reshape(bsz, 1, 1)
- # Final text conditioning.
- null_conditioning = torch.zeros_like(encoder_hidden_states)
- encoder_hidden_states = torch.where(
- prompt_mask, null_conditioning, encoder_hidden_states)
-
- # Sample masks for the original images.
- image_mask_dtype = conditional_latents.dtype
- image_mask = 1 - (
- (random_p >= args.conditioning_dropout_prob).to(
- image_mask_dtype)
- * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
- )
- image_mask = image_mask.reshape(bsz, 1, 1, 1)
- # Final image conditioning.
- conditional_latents = image_mask * conditional_latents
-
- # Concatenate the `conditional_latents` with the `noisy_latents`.
- conditional_latents = conditional_latents.unsqueeze(
- 1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
-
- pose_pixels = batch["pose_pixels"].to(
- dtype=weight_dtype, device=accelerator.device, non_blocking=True
- )
- faceid_embeds = batch["faceid_embeds"].to(
- dtype=weight_dtype, device=accelerator.device, non_blocking=True
- )
- pose_latents = pose_net(pose_pixels)
-
- # print("This is faceid_latents calculation")
- # print(faceid_embeds.size()) # [1, 512]
- # print(image_embed.size()) # [1, 1, 1024]
-
- faceid_latents = face_encoder(faceid_embeds, image_embed)
-
-
- inp_noisy_latents = torch.cat(
- [inp_noisy_latents, conditional_latents], dim=2)
- target = latents
-
- # print(f"the size of encoder_hidden_states: {encoder_hidden_states.size()}") # [1, 1, 1024]
- # print(f"the size of face latents: {faceid_latents.size()}") # [1, 4, 1024]
- encoder_hidden_states = torch.cat([encoder_hidden_states, faceid_latents], dim=1)
-
- encoder_hidden_states = encoder_hidden_states.to(latents.dtype)
- inp_noisy_latents = inp_noisy_latents.to(latents.dtype)
- pose_latents = pose_latents.to(latents.dtype)
-
- # Predict the noise residual
- model_pred = unet(
- inp_noisy_latents, timesteps, encoder_hidden_states,
- added_time_ids=added_time_ids,
- pose_latents=pose_latents,
- ).sample
-
-
- sigmas = sigmas_reshaped
- # Denoise the latents
- c_out = -sigmas / ((sigmas**2 + 1)**0.5)
- c_skip = 1 / (sigmas**2 + 1)
- denoised_latents = model_pred * c_out + c_skip * noisy_latents
- weighing = (1 + sigmas ** 2) * (sigmas**-2.0)
-
- tgt_face_masks = batch["tgt_face_masks"].to(
- dtype=weight_dtype, device=accelerator.device, non_blocking=True
- )
- tgt_face_masks = rearrange(tgt_face_masks, "b f c h w -> (b f) c h w")
- tgt_face_masks = F.interpolate(tgt_face_masks, size=(target.size()[-2], target.size()[-1]), mode='nearest')
- tgt_face_masks = rearrange(tgt_face_masks, "(b f) c h w -> b f c h w", f=args.sample_n_frames)
-
- # MSE loss
- loss = torch.mean(
- (weighing.float() * (denoised_latents.float() -
- target.float()) ** 2 * (1 + tgt_face_masks)).reshape(target.shape[0], -1),
- dim=1,
- )
- loss = loss.mean()
-
- # Gather the losses across all processes for logging (if we use distributed training).
- avg_loss = accelerator.gather(
- loss.repeat(args.per_gpu_batch_size)).mean()
- train_loss += avg_loss.item() / args.gradient_accumulation_steps
-
- # Backpropagate
- accelerator.backward(loss)
- # if accelerator.sync_gradients:
- # accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
-
- with torch.cuda.device(latents.device):
- torch.cuda.empty_cache()
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- if args.use_ema:
- ema_unet.step(unet.parameters())
- progress_bar.update(1)
- global_step += 1
- accelerator.log({"train_loss": train_loss}, step=global_step)
- train_loss = 0.0
-
- # save checkpoints!
- # if global_step % args.checkpointing_steps == 0 and (accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED):
- if global_step % args.checkpointing_steps == 0 and accelerator.is_main_process:
-
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
- if args.checkpoints_total_limit is not None and accelerator.is_main_process:
- checkpoints = os.listdir(args.output_dir)
- checkpoints = [
- d for d in checkpoints if d.startswith("checkpoint")]
- checkpoints = sorted(
- checkpoints, key=lambda x: int(x.split("-")[1]))
-
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
- if len(checkpoints) >= args.checkpoints_total_limit:
- num_to_remove = len(
- checkpoints) - args.checkpoints_total_limit + 1
- removing_checkpoints = checkpoints[0:num_to_remove]
-
- logger.info(
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
- )
- logger.info(
- f"removing checkpoints: {', '.join(removing_checkpoints)}")
-
- for removing_checkpoint in removing_checkpoints:
- removing_checkpoint = os.path.join(
- args.output_dir, removing_checkpoint)
- shutil.rmtree(removing_checkpoint)
-
- save_path = os.path.join(
- args.output_dir, f"checkpoint-{global_step}")
- accelerator.save_state(save_path)
- unwrap_unet = accelerator.unwrap_model(unet)
- unwrap_pose_net = accelerator.unwrap_model(pose_net)
- unwrap_face_encoder = accelerator.unwrap_model(face_encoder)
- unwrap_unet_state_dict = unwrap_unet.state_dict()
- torch.save(unwrap_unet_state_dict, os.path.join(args.output_dir, f"checkpoint-{global_step}", f"unet-{global_step}.pth"))
- unwrap_pose_net_state_dict = unwrap_pose_net.state_dict()
- torch.save(unwrap_pose_net_state_dict, os.path.join(args.output_dir, f"checkpoint-{global_step}", f"pose_net-{global_step}.pth"))
- unwrap_face_encoder_state_dict = unwrap_face_encoder.state_dict()
- torch.save(unwrap_face_encoder_state_dict, os.path.join(args.output_dir, f"checkpoint-{global_step}", f"face_encoder-{global_step}.pth"))
- logger.info(f"Saved state to {save_path}")
-
- if accelerator.is_main_process:
- # sample images!
- if global_step % args.validation_steps == 0:
- logger.info(
- f"Running validation... \n Generating {args.num_validation_images} videos."
- )
- # create pipeline
- if args.use_ema:
- # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
- ema_unet.store(unet.parameters())
- ema_unet.copy_to(unet.parameters())
-
- log_validation(
- vae=vae,
- image_encoder=image_encoder,
- unet=unet,
- pose_net=pose_net,
- face_encoder=face_encoder,
- app=face_model.app,
- face_helper=face_model.face_helper,
- handler_ante=face_model.handler_ante,
- scheduler=noise_scheduler,
- accelerator=accelerator,
- feature_extractor=feature_extractor,
- width=512,
- height=512,
- torch_dtype=weight_dtype,
- validation_image_folder=args.validation_image_folder,
- validation_image=args.validation_image,
- validation_control_folder=args.validation_control_folder,
- output_dir=args.output_dir,
- generator=generator,
- global_step=global_step,
- num_validation_cases=1,
- )
-
- if args.use_ema:
- # Switch back to the original UNet parameters.
- ema_unet.restore(unet.parameters())
-
- with torch.cuda.device(latents.device):
- torch.cuda.empty_cache()
-
- logs = {"step_loss": loss.detach().item(
- ), "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
-
- if global_step >= args.max_train_steps:
- break
-
-
- # save checkpoints!
- # if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
- if accelerator.is_main_process:
- save_path = os.path.join(
- args.output_dir, f"checkpoint-last")
- accelerator.save_state(save_path)
- logger.info(f"Saved state to {save_path}")
-
-
-def log_validation(
- vae,
- image_encoder,
- unet,
- pose_net,
- face_encoder,
- app,
- face_helper,
- handler_ante,
- scheduler,
- accelerator,
- feature_extractor,
- width,
- height,
- torch_dtype,
- validation_image_folder,
- validation_image,
- validation_control_folder,
- output_dir,
- generator,
- global_step,
- num_validation_cases=1,
-):
- logger.info("Running validation... ")
- validation_unet = accelerator.unwrap_model(unet)
- validation_image_encoder = accelerator.unwrap_model(image_encoder)
- validation_vae = accelerator.unwrap_model(vae)
- validation_pose_net = accelerator.unwrap_model(pose_net)
- validation_face_encoder = accelerator.unwrap_model(face_encoder)
-
- pipeline = ValidationAnimationPipeline(
- vae=validation_vae,
- image_encoder=validation_image_encoder,
- unet=validation_unet,
- scheduler=scheduler,
- feature_extractor=feature_extractor,
- pose_net=validation_pose_net,
- face_encoder=validation_face_encoder,
- )
- pipeline = pipeline.to(accelerator.device)
- validation_images = load_images_from_folder(validation_image_folder)
- validation_image_path = validation_image
- if validation_image is None:
- validation_image = validation_images[0]
- else:
- validation_image = Image.open(validation_image).convert('RGB')
- validation_control_images = load_images_from_folder(validation_control_folder)
-
- val_save_dir = os.path.join(output_dir, "validation_images")
- if not os.path.exists(val_save_dir):
- os.makedirs(val_save_dir)
-
- with accelerator.autocast():
- for val_img_idx in range(num_validation_cases):
- # num_frames = args.num_frames
- num_frames = len(validation_control_images)
-
- face_helper.clean_all()
- validation_face = cv2.imread(validation_image_path)
- # validation_image_bgr = cv2.cvtColor(validation_face, cv2.COLOR_RGB2BGR)
- validation_image_face_info = app.get(validation_face)
- if len(validation_image_face_info) > 0:
- validation_image_face_info = sorted(validation_image_face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
- validation_image_id_ante_embedding = validation_image_face_info['embedding']
- else:
- validation_image_id_ante_embedding = None
-
- if validation_image_id_ante_embedding is None:
- face_helper.read_image(validation_face)
- face_helper.get_face_landmarks_5(only_center_face=True)
- face_helper.align_warp_face()
-
- if len(face_helper.cropped_faces) == 0:
- validation_image_id_ante_embedding = np.zeros((512,))
- else:
- validation_image_align_face = face_helper.cropped_faces[0]
- print('fail to detect face using insightface, extract embedding on align face')
- validation_image_id_ante_embedding = handler_ante.get_feat(validation_image_align_face)
-
- video_frames = pipeline(
- image=validation_image,
- image_pose=validation_control_images,
- height=height,
- width=width,
- num_frames=num_frames,
- tile_size=num_frames,
- tile_overlap=4,
- decode_chunk_size=4,
- motion_bucket_id=127.,
- fps=7,
- min_guidance_scale=3,
- max_guidance_scale=3,
- noise_aug_strength=0.02,
- num_inference_steps=25,
- generator=generator,
- output_type="pil",
- validation_image_id_ante_embedding=validation_image_id_ante_embedding,
- ).frames[0]
- # save_combined_frames(video_frames, validation_images, validation_control_images, val_save_dir)
-
- out_file = os.path.join(
- val_save_dir,
- f"step_{global_step}_val_img_{val_img_idx}.mp4",
- )
- # print(video_frames.size()) # [16, 3, 512, 512]
- for i in range(num_frames):
- img = video_frames[i]
- video_frames[i] = np.array(img)
- export_to_gif(video_frames, out_file, 8)
-
- del pipeline
- torch.cuda.empty_cache()
-
-
-
-if __name__ == "__main__":
- main()
diff --git a/train_single.py b/train_single.py
index 1cb5ddc..7672bfa 100644
--- a/train_single.py
+++ b/train_single.py
@@ -1,1670 +1,1621 @@
-import argparse
-import random
-import logging
-import math
-import os
-
-import cv2
-import shutil
-from pathlib import Path
-from urllib.parse import urlparse
-import numpy as np
-import PIL
-from PIL import Image, ImageDraw
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from diffusers.models.attention_processor import XFormersAttnProcessor
-
-from animation.dataset.animation_dataset import LargeScaleAnimationVideos
-from animation.modules.attention_processor import AnimationAttnProcessor
-from animation.modules.attention_processor_normalized import AnimationIDAttnNormalizedProcessor
-from animation.modules.face_model import FaceModel
-from animation.modules.id_encoder import FusionFaceId
-from animation.modules.pose_net import PoseNet
-from animation.modules.unet import UNetSpatioTemporalConditionModel
-
-from animation.pipelines.validation_pipeline_animation import ValidationAnimationPipeline
-import transformers
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
-from packaging import version
-from tqdm.auto import tqdm
-from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
-from einops import rearrange
-
-import datetime
-import diffusers
-from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
-from diffusers.utils.import_utils import is_xformers_available
-import warnings
-import torch.nn as nn
-from diffusers.utils.torch_utils import randn_tensor
-
-# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0.dev0")
-
-logger = get_logger(__name__, log_level="INFO")
-
-
-# i should make a utility function file
-def validate_and_convert_image(image, target_size=(256, 256)):
- if image is None:
- print("Encountered a None image")
- return None
-
- if isinstance(image, torch.Tensor):
- # Convert PyTorch tensor to PIL Image
- if image.ndim == 3 and image.shape[0] in [1, 3]: # Check for CxHxW format
- if image.shape[0] == 1: # Convert single-channel grayscale to RGB
- image = image.repeat(3, 1, 1)
- image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
- image = Image.fromarray(image)
- else:
- print(f"Invalid image tensor shape: {image.shape}")
- return None
- elif isinstance(image, Image.Image):
- # Resize PIL Image
- image = image.resize(target_size)
- else:
- print("Image is not a PIL Image or a PyTorch tensor")
- return None
-
- return image
-
-
-def create_image_grid(images, rows, cols, target_size=(256, 256)):
- valid_images = [validate_and_convert_image(img, target_size) for img in images]
- valid_images = [img for img in valid_images if img is not None]
-
- if not valid_images:
- print("No valid images to create a grid")
- return None
-
- w, h = target_size
- grid = Image.new('RGB', size=(cols * w, rows * h))
-
- for i, image in enumerate(valid_images):
- grid.paste(image, box=((i % cols) * w, (i // cols) * h))
-
- return grid
-
-
-def save_combined_frames(batch_output, validation_images, validation_control_images, output_folder):
- # Flatten batch_output, which is a list of lists of PIL Images
- flattened_batch_output = [img for sublist in batch_output for img in sublist]
-
- # Combine frames into a list without converting (since they are already PIL Images)
- combined_frames = validation_images + validation_control_images + flattened_batch_output
-
- # Calculate rows and columns for the grid
- num_images = len(combined_frames)
- cols = 3 # adjust number of columns as needed
- rows = (num_images + cols - 1) // cols
- timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-
- filename = f"combined_frames_{timestamp}.png"
- # Create and save the grid image
- grid = create_image_grid(combined_frames, rows, cols)
- output_folder = os.path.join(output_folder, "validation_images")
- os.makedirs(output_folder, exist_ok=True)
-
- # Now define the full path for the file
- timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
- filename = f"combined_frames_{timestamp}.png"
- output_loc = os.path.join(output_folder, filename)
-
- if grid is not None:
- grid.save(output_loc)
- else:
- print("Failed to create image grid")
-
-
-# def load_images_from_folder(folder):
-# images = []
-# valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed
-#
-# # Function to extract frame number from the filename
-# def frame_number(filename):
-# # First, try the pattern 'frame_x_7fps'
-# new_pattern_match = re.search(r'frame_(\d+)_7fps', filename)
-# if new_pattern_match:
-# return int(new_pattern_match.group(1))
-# # If the new pattern is not found, use the original digit extraction method
-# matches = re.findall(r'\d+', filename)
-# if matches:
-# if matches[-1] == '0000' and len(matches) > 1:
-# return int(matches[-2]) # Return the second-to-last sequence if the last is '0000'
-# return int(matches[-1]) # Otherwise, return the last sequence
-# return float('inf') # Return 'inf'
-#
-# # Sorting files based on frame number
-# sorted_files = sorted(os.listdir(folder), key=frame_number)
-#
-# # Load images in sorted order
-# for filename in sorted_files:
-# ext = os.path.splitext(filename)[1].lower()
-# if ext in valid_extensions:
-# img = Image.open(os.path.join(folder, filename)).convert('RGB')
-# images.append(img)
-#
-# return images
-
-def load_images_from_folder(folder):
- images = []
-
- files = os.listdir(folder)
- png_files = [f for f in files if f.endswith('.png')]
- png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
- for filename in png_files:
- img = Image.open(os.path.join(folder, filename)).convert('RGB')
- images.append(img)
-
- return images
-
-
-# copy from https://github.com/crowsonkb/k-diffusion.git
-def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
- """Draws stratified samples from a uniform distribution."""
- if groups <= 0:
- raise ValueError(f"groups must be positive, got {groups}")
- if group < 0 or group >= groups:
- raise ValueError(f"group must be in [0, {groups})")
- n = shape[-1] * groups
- offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
- u = torch.rand(shape, dtype=dtype, device=device)
- return (offsets + u) / n
-
-
-def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3,
- device='cpu', dtype=torch.float32):
- """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""
-
- def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
- t_min = math.atan(math.exp(-0.5 * logsnr_max))
- t_max = math.atan(math.exp(-0.5 * logsnr_min))
- return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
-
- def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
- shift = 2 * math.log(noise_d / image_d)
- return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift
-
- def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
- logsnr_low = logsnr_schedule_cosine_shifted(
- t, image_d, noise_d_low, logsnr_min, logsnr_max)
- logsnr_high = logsnr_schedule_cosine_shifted(
- t, image_d, noise_d_high, logsnr_min, logsnr_max)
- return torch.lerp(logsnr_low, logsnr_high, t)
-
- logsnr_min = -2 * math.log(min_value / sigma_data)
- logsnr_max = -2 * math.log(max_value / sigma_data)
- u = stratified_uniform(
- shape, group=0, groups=1, dtype=dtype, device=device
- )
- logsnr = logsnr_schedule_cosine_interpolated(
- u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
- return torch.exp(-logsnr / 2) * sigma_data
-
-
-def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
- """Draws samples from an lognormal distribution."""
- u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
- return torch.distributions.Normal(loc, scale).icdf(u).exp()
-
-
-min_value = 0.002
-max_value = 700
-image_d = 64
-noise_d_low = 32
-noise_d_high = 64
-sigma_data = 0.5
-
-
-def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
- h, w = input.shape[-2:]
- factors = (h / size[0], w / size[1])
-
- # First, we have to determine sigma
- # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
- sigmas = (
- max((factors[0] - 1.0) / 2.0, 0.001),
- max((factors[1] - 1.0) / 2.0, 0.001),
- )
-
- # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
- # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
- # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
- ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
-
- # Make sure it is odd
- if (ks[0] % 2) == 0:
- ks = ks[0] + 1, ks[1]
-
- if (ks[1] % 2) == 0:
- ks = ks[0], ks[1] + 1
-
- input = _gaussian_blur2d(input, ks, sigmas)
-
- output = torch.nn.functional.interpolate(
- input, size=size, mode=interpolation, align_corners=align_corners)
- return output
-
-
-def _compute_padding(kernel_size):
- """Compute padding tuple."""
- # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
- # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
- if len(kernel_size) < 2:
- raise AssertionError(kernel_size)
- computed = [k - 1 for k in kernel_size]
-
- # for even kernels we need to do asymmetric padding :(
- out_padding = 2 * len(kernel_size) * [0]
-
- for i in range(len(kernel_size)):
- computed_tmp = computed[-(i + 1)]
-
- pad_front = computed_tmp // 2
- pad_rear = computed_tmp - pad_front
-
- out_padding[2 * i + 0] = pad_front
- out_padding[2 * i + 1] = pad_rear
-
- return out_padding
-
-
-def _filter2d(input, kernel):
- # prepare kernel
- b, c, h, w = input.shape
- tmp_kernel = kernel[:, None, ...].to(
- device=input.device, dtype=input.dtype)
-
- tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
-
- height, width = tmp_kernel.shape[-2:]
-
- padding_shape: list[int] = _compute_padding([height, width])
- input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
-
- # kernel and input tensor reshape to align element-wise or batch-wise params
- tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
- input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
-
- # convolve the tensor with the kernel.
- output = torch.nn.functional.conv2d(
- input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
-
- out = output.view(b, c, h, w)
- return out
-
-
-def _gaussian(window_size: int, sigma):
- if isinstance(sigma, float):
- sigma = torch.tensor([[sigma]])
-
- batch_size = sigma.shape[0]
-
- x = (torch.arange(window_size, device=sigma.device,
- dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
-
- if window_size % 2 == 0:
- x = x + 0.5
-
- gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
-
- return gauss / gauss.sum(-1, keepdim=True)
-
-
-def _gaussian_blur2d(input, kernel_size, sigma):
- if isinstance(sigma, tuple):
- sigma = torch.tensor([sigma], dtype=input.dtype)
- else:
- sigma = sigma.to(dtype=input.dtype)
-
- ky, kx = int(kernel_size[0]), int(kernel_size[1])
- bs = sigma.shape[0]
- kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
- kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
- out_x = _filter2d(input, kernel_x[..., None, :])
- out = _filter2d(out_x, kernel_y[..., None])
-
- return out
-
-
-def export_to_video(video_frames, output_video_path, fps):
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
- h, w, _ = video_frames[0].shape
- video_writer = cv2.VideoWriter(
- output_video_path, fourcc, fps=fps, frameSize=(w, h))
- for i in range(len(video_frames)):
- img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
- video_writer.write(img)
-
-
-def export_to_gif(frames, output_gif_path, fps):
- """
- Export a list of frames to a GIF.
-
- Args:
- - frames (list): List of frames (as numpy arrays or PIL Image objects).
- - output_gif_path (str): Path to save the output GIF.
- - duration_ms (int): Duration of each frame in milliseconds.
-
- """
- # Convert numpy arrays to PIL Images if needed
- pil_frames = [Image.fromarray(frame) if isinstance(
- frame, np.ndarray) else frame for frame in frames]
-
- pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'),
- format='GIF',
- append_images=pil_frames[1:],
- save_all=True,
- duration=125,
- loop=0)
-
-
-def tensor_to_vae_latent(t, vae, scale=True):
- t = t.to(vae.dtype)
- if len(t.shape) == 5:
- video_length = t.shape[1]
-
- t = rearrange(t, "b f c h w -> (b f) c h w")
- latents = vae.encode(t).latent_dist.sample()
- latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
- elif len(t.shape) == 4:
- latents = vae.encode(t).latent_dist.sample()
- if scale:
- latents = latents * vae.config.scaling_factor
- return latents
-
-
-def parse_args():
- parser = argparse.ArgumentParser(
- description="Script to train Stable Diffusion XL for InstructPix2Pix."
- )
- parser.add_argument(
- "--pretrained_model_name_or_path",
- type=str,
- default=None,
- required=True,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--revision",
- type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
-
- parser.add_argument(
- "--num_frames",
- type=int,
- default=14,
- )
- parser.add_argument(
- "--dataset_type",
- type=str,
- default='ubc',
- )
- parser.add_argument(
- "--num_validation_images",
- type=int,
- default=1,
- help="Number of images that should be generated during validation with `validation_prompt`.",
- )
- parser.add_argument(
- "--validation_steps",
- type=int,
- default=500,
- help=(
- "Run fine-tuning validation every X epochs. The validation process consists of running the text/image prompt"
- " multiple times: `args.num_validation_images`."
- ),
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="./outputs",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- parser.add_argument(
- "--seed", type=int, default=None, help="A seed for reproducible training."
- )
- parser.add_argument(
- "--per_gpu_batch_size",
- type=int,
- default=1,
- help="Batch size (per device) for the training dataloader.",
- )
- parser.add_argument("--num_train_epochs", type=int, default=100)
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=None,
- help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
- )
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
- parser.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=1e-4,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument(
- "--scale_lr",
- action="store_true",
- default=False,
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- parser.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- parser.add_argument(
- "--lr_warmup_steps",
- type=int,
- default=500,
- help="Number of steps for the warmup in the lr scheduler.",
- )
- parser.add_argument(
- "--conditioning_dropout_prob",
- type=float,
- default=0.1,
- help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
- )
- parser.add_argument(
- "--use_8bit_adam",
- action="store_true",
- help="Whether or not to use 8-bit Adam from bitsandbytes.",
- )
- parser.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- parser.add_argument(
- "--use_ema", action="store_true", help="Whether to use EMA model."
- )
- parser.add_argument(
- "--non_ema_revision",
- type=str,
- default=None,
- required=False,
- help=(
- "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
- " remote repository specified with --pretrained_model_name_or_path."
- ),
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=8,
- help=(
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
- ),
- )
- parser.add_argument(
- "--adam_beta1",
- type=float,
- default=0.9,
- help="The beta1 parameter for the Adam optimizer.",
- )
- parser.add_argument(
- "--adam_beta2",
- type=float,
- default=0.999,
- help="The beta2 parameter for the Adam optimizer.",
- )
- parser.add_argument(
- "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
- )
- parser.add_argument(
- "--adam_epsilon",
- type=float,
- default=1e-08,
- help="Epsilon value for the Adam optimizer",
- )
- parser.add_argument(
- "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
- )
- parser.add_argument(
- "--push_to_hub",
- action="store_true",
- help="Whether or not to push the model to the Hub.",
- )
- parser.add_argument(
- "--hub_token",
- type=str,
- default=None,
- help="The token to use to push to the Model Hub.",
- )
- parser.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- parser.add_argument(
- "--logging_dir",
- type=str,
- default="logs",
- help=(
- "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
- " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
- ),
- )
- parser.add_argument(
- "--mixed_precision",
- type=str,
- default=None,
- choices=["no", "fp16", "bf16"],
- help=(
- "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
- " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
- " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
- ),
- )
- parser.add_argument(
- "--report_to",
- type=str,
- default="tensorboard",
- help=(
- 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
- ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
- ),
- )
- parser.add_argument(
- "--local_rank",
- type=int,
- default=-1,
- help="For distributed training: local_rank",
- )
- parser.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
- " training using `--resume_from_checkpoint`."
- ),
- )
- parser.add_argument(
- "--checkpoints_total_limit",
- type=int,
- default=1,
- help=("Max number of checkpoints to store."),
- )
- parser.add_argument(
- "--resume_from_checkpoint",
- type=str,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- parser.add_argument(
- "--enable_xformers_memory_efficient_attention",
- action="store_true",
- help="Whether or not to use xformers.",
- )
- parser.add_argument(
- "--log_trainable_parameters",
- action="store_true",
- help="Whether to write the trainable parameters.",
- )
- parser.add_argument(
- "--pretrain_unet",
- type=str,
- default=None,
- help="use weight for unet block",
- )
- parser.add_argument(
- "--rank",
- type=int,
- default=128,
- help=("The dimension of the LoRA update matrices."),
- )
- parser.add_argument(
- "--csv_path",
- type=str,
- default=None,
- help=(
- "path to the dataset csv"
- ),
- )
- parser.add_argument(
- "--video_folder",
- type=str,
- default=None,
- help=(
- "path to the video folder"
- ),
- )
- parser.add_argument(
- "--condition_folder",
- type=str,
- default=None,
- help=(
- "path to the depth folder"
- ),
- )
- parser.add_argument(
- "--motion_folder",
- type=str,
- default=None,
- help=(
- "path to the depth folder"
- ),
- )
- parser.add_argument(
- "--validation_prompt",
- type=str,
- default=None,
- help=(
- "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
- " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
- " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
- ),
- )
- parser.add_argument(
- "--validation_image_folder",
- type=str,
- default=None,
- help=(
- "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`"
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
- " `--validation_image` that will be used with all `--validation_prompt`s."
- ),
- )
- parser.add_argument(
- "--validation_image",
- type=str,
- default=None,
- help=(
- "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`"
- " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
- " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
- " `--validation_image` that will be used with all `--validation_prompt`s."
- ),
- )
- parser.add_argument(
- "--validation_control_folder",
- type=str,
- default=None,
- help=(
- "the validation control image"
- ),
- )
- parser.add_argument(
- "--sample_n_frames",
- type=int,
- default=14,
- help=(
- "the sample_n_frames"
- ),
- )
-
- parser.add_argument(
- "--ref_augment",
- action="store_true",
- help=(
- "use augment for the reference image"
- ),
- )
- parser.add_argument(
- "--train_stage",
- type=int,
- default=2,
- help=(
- "the training stage"
- ),
- )
-
- parser.add_argument(
- "--posenet_model_name_or_path",
- type=str,
- default=None,
- help="Path to pretrained posenet model",
- )
- parser.add_argument(
- "--face_encoder_model_name_or_path",
- type=str,
- default=None,
- help="Path to pretrained face encoder model",
- )
- parser.add_argument(
- "--unet_model_name_or_path",
- type=str,
- default=None,
- help="Path to pretrained unet model",
- )
-
- parser.add_argument(
- "--data_root_path",
- type=str,
- default=None,
- help="Path to the data root path",
- )
- parser.add_argument(
- "--data_path",
- type=str,
- default=None,
- help="Path to the data path",
- )
-
- parser.add_argument(
- "--finetune_mode",
- type=bool,
- default=False,
- help="Enable or disable the finetune mode (True/False).",
- )
- parser.add_argument(
- "--posenet_model_finetune_path",
- type=str,
- default=None,
- help="Path to the pretrained posenet model",
- )
- parser.add_argument(
- "--face_encoder_finetune_path",
- type=str,
- default=None,
- help="Path to the pretrained face encoder",
- )
- parser.add_argument(
- "--unet_model_finetune_path",
- type=str,
- default=None,
- help="Path to the pretrained unet model",
- )
-
- parser.add_argument(
- "--dataset_width",
- type=int,
- default=512,
- help="video dataset width",
- )
- parser.add_argument(
- "--dataset_height",
- type=int,
- default=512,
- help="video dataset height",
- )
-
- args = parser.parse_args()
- env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
- if env_local_rank != -1 and env_local_rank != args.local_rank:
- args.local_rank = env_local_rank
-
- # default to using the same revision for the non-ema model if not specified
- if args.non_ema_revision is None:
- args.non_ema_revision = args.revision
-
- return args
-
-
-def download_image(url):
- original_image = (
- lambda image_url_or_path: load_image(image_url_or_path)
- if urlparse(image_url_or_path).scheme
- else PIL.Image.open(image_url_or_path).convert("RGB")
- )(url)
- return original_image
-
-
-# This is for training using deepspeed.
-# Since now the DeepSpeed only supports trainging with only one model
-# So we create a virtual wrapper to contail all the models
-
-class DeepSpeedWrapperModel(nn.Module):
- def __init__(self, **kwargs):
- super().__init__()
- for name, value in kwargs.items():
- assert isinstance(value, nn.Module)
- self.register_module(name, value)
-
-
-def main():
- warnings.filterwarnings('ignore', category=DeprecationWarning)
- warnings.filterwarnings('ignore', category=FutureWarning)
- torch.multiprocessing.set_start_method('spawn')
-
- args = parse_args()
-
- if args.non_ema_revision is not None:
- deprecate(
- "non_ema_revision!=None",
- "0.15.0",
- message=(
- "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
- " use `--variant=non_ema` instead."
- ),
- )
- logging_dir = os.path.join(args.output_dir, args.logging_dir)
- accelerator_project_config = ProjectConfiguration(
- project_dir=args.output_dir, logging_dir=logging_dir)
- # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
- accelerator = Accelerator(
- gradient_accumulation_steps=args.gradient_accumulation_steps,
- mixed_precision=args.mixed_precision,
- project_config=accelerator_project_config,
- )
-
- generator = torch.Generator(
- device=accelerator.device).manual_seed(23123134)
-
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError(
- "Make sure to install wandb if you want to use it for logging during training.")
- import wandb
-
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO,
- )
- logger.info(accelerator.state, main_process_only=False)
- if accelerator.is_local_main_process:
- transformers.utils.logging.set_verbosity_warning()
- diffusers.utils.logging.set_verbosity_info()
- else:
- transformers.utils.logging.set_verbosity_error()
- diffusers.utils.logging.set_verbosity_error()
-
- # If passed along, set the training seed now.
- if args.seed is not None:
- set_seed(args.seed)
-
- # Handle the repository creation
- if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
- if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
-
- # Load scheduler, tokenizer and models.
- print(args.pretrained_model_name_or_path)
- feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path,
- subfolder="feature_extractor", revision=args.revision)
- noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision
- )
- vae = AutoencoderKLTemporalDecoder.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant="fp16")
- unet = UNetSpatioTemporalConditionModel.from_pretrained(
- args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet,
- subfolder="unet",
- low_cpu_mem_usage=True,
- variant="fp16"
- )
- pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0])
- face_encoder = FusionFaceId(
- cross_attention_dim=1024,
- id_embeddings_dim=512,
- clip_embeddings_dim=1024,
- num_tokens=4, )
- face_model = FaceModel()
-
- # init adapter modules
- lora_rank = 128
- attn_procs = {}
- unet_svd = unet.state_dict()
-
- for name in unet.attn_processors.keys():
- if "transformer_blocks" in name and "temporal_transformer_blocks" not in name:
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
- if cross_attention_dim is None:
- # print(f"This is AnimationAttnProcessor: {name}")
- attn_procs[name] = AnimationAttnProcessor(hidden_size=hidden_size,
- cross_attention_dim=cross_attention_dim, rank=lora_rank)
- else:
- # print(f"This is AnimationIDAttnNormalizedProcessor: {name}")
- layer_name = name.split(".processor")[0]
- weights = {
- "id_to_k.weight": unet_svd[layer_name + ".to_k.weight"],
- "id_to_v.weight": unet_svd[layer_name + ".to_v.weight"],
- }
- attn_procs[name] = AnimationIDAttnNormalizedProcessor(hidden_size=hidden_size,
- cross_attention_dim=cross_attention_dim,
- rank=lora_rank)
- attn_procs[name].load_state_dict(weights, strict=False)
- elif "temporal_transformer_blocks" in name:
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
- if cross_attention_dim is None:
- attn_procs[name] = XFormersAttnProcessor()
- else:
- attn_procs[name] = XFormersAttnProcessor()
- unet.set_attn_processor(attn_procs)
-
- # triggering the finetune mode
- if args.finetune_mode is True and args.posenet_model_finetune_path is not None and args.face_encoder_finetune_path is not None and args.unet_model_finetune_path is not None:
- print("Loading existing posenet weights, face_encoder weights and unet weights.")
- if args.posenet_model_finetune_path.endswith(".pth"):
- pose_net_state_dict = torch.load(args.posenet_model_finetune_path, map_location="cpu")
- pose_net.load_state_dict(pose_net_state_dict, strict=True)
- else:
- print("posenet weights loading fail")
- print(1 / 0)
- if args.face_encoder_finetune_path.endswith(".pth"):
- face_encoder_state_dict = torch.load(args.face_encoder_finetune_path, map_location="cpu")
- face_encoder.load_state_dict(face_encoder_state_dict, strict=True)
- else:
- print("face_encoder weights loading fail")
- print(1 / 0)
- if args.unet_model_finetune_path.endswith(".pth"):
- unet_state_dict = torch.load(args.unet_model_finetune_path, map_location="cpu")
- unet.load_state_dict(unet_state_dict, strict=True)
- else:
- print("unet weights loading fail")
- print(1 / 0)
-
- vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
- image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
-
- # Freeze vae and image_encoder
- vae.requires_grad_(False)
- image_encoder.requires_grad_(False)
- unet.requires_grad_(False)
- pose_net.requires_grad_(False)
- face_encoder.requires_grad_(False)
-
- weight_dtype = torch.float32
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- image_encoder.to(accelerator.device, dtype=weight_dtype)
- vae.to(accelerator.device, dtype=weight_dtype)
-
- if args.use_ema:
- ema_unet = EMAModel(unet.parameters(
- ), model_cls=UNetSpatioTemporalConditionModel, model_config=unet.config)
-
- if args.enable_xformers_memory_efficient_attention:
- if is_xformers_available():
- import xformers
- xformers_version = version.parse(xformers.__version__)
- if xformers_version == version.parse("0.0.16"):
- logger.warn(
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
- )
- unet.enable_xformers_memory_efficient_attention()
- else:
- raise ValueError(
- "xformers is not available. Make sure it is installed correctly")
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
-
- # Enable TF32 for faster training on Ampere GPUs,
- # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if args.allow_tf32:
- torch.backends.cuda.matmul.allow_tf32 = True
-
- if args.scale_lr:
- args.learning_rate = (
- args.learning_rate * args.gradient_accumulation_steps *
- args.per_gpu_batch_size * accelerator.num_processes
- )
-
- # Initialize the optimizer
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
- )
-
- optimizer_cls = bnb.optim.AdamW8bit
- else:
- optimizer_cls = torch.optim.AdamW
-
- # if accelerator.distributed_type == DistributedType.DEEPSPEED:
- # ds_wrapper = DeepSpeedWrapperModel(
- # unet=unet,
- # controlnext=controlnext
- # )
- # unet = ds_wrapper.unet
- # controlnext = ds_wrapper.controlnext
-
- pose_net.requires_grad_(True)
- face_encoder.requires_grad_(True)
-
- parameters_list = []
-
- for name, para in pose_net.named_parameters():
- para.requires_grad = True
- parameters_list.append({"params": para, "lr": args.learning_rate})
-
- for name, para in face_encoder.named_parameters():
- para.requires_grad = True
- parameters_list.append({"params": para, "lr": args.learning_rate})
-
- """
- For more details, please refer to: https://github.com/dvlab-research/ControlNeXt/issues/14#issuecomment-2290450333
- This is the selective parameters part.
- As presented in our paper, we only select a small subset of parameters, which is fully adapted to the SD1.5 and SDXL backbones. By training fewer than 100 million parameters, we still achieve excellent performance. But this is is not suitable for the SD3 and SVD training. This is because, after SDXL, Stability faced significant legal risks due to the generation of highly realistic human images. After that, they stopped refining their models on human-related data, such as SVD and SD3, to avoid potential risks.
- To achieve optimal performance, it's necessary to first continue training SVD and SD3 on human-related data to develop a robust backbone before fine-tuning. Of course, you can also combine the continual pretraining and finetuning. So you can find that we direct provide the full SVD parameters.
- We have experimented with two approaches: 1.Directly training the model from scratch on human dancing data. 2. Continual training using a pre-trained human generation backbone, followed by fine-tuning a selective small subset of parameters. Interestingly, we observed no significant difference in performance between these two methods.
- """
-
- for name, para in unet.named_parameters():
- if "attentions" in name:
- para.requires_grad = True
- parameters_list.append({"params": para})
- else:
- para.requires_grad = False
-
- optimizer = optimizer_cls(
- parameters_list,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- )
-
- # check para
- if accelerator.is_main_process and args.log_trainable_parameters:
- rec_txt1 = open('rec_para.txt', 'w')
- rec_txt2 = open('rec_para_train.txt', 'w')
- for name, para in unet.named_parameters():
- if para.requires_grad is False:
- rec_txt1.write(f'{name}\n')
- else:
- rec_txt2.write(f'{name}\n')
- rec_txt1.close()
- rec_txt2.close()
- # DataLoaders creation:
- args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
-
- root_path = args.data_root_path
- txt_path = args.data_path
- train_dataset = LargeScaleAnimationVideos(
- root_path=root_path,
- txt_path=txt_path,
- width=args.dataset_width,
- height=args.dataset_height,
- n_sample_frames=args.sample_n_frames,
- sample_frame_rate=4,
- app=face_model.app,
- handler_ante=face_model.handler_ante,
- face_helper=face_model.face_helper
- )
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset,
- batch_size=args.per_gpu_batch_size,
- num_workers=args.num_workers,
- shuffle=True,
- )
-
- # Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
-
- lr_scheduler = get_scheduler(
- args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
- )
-
- unet, pose_net, face_encoder, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
- unet, pose_net, face_encoder, optimizer, lr_scheduler, train_dataloader
- )
-
- if args.use_ema:
- ema_unet.to(accelerator.device)
-
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- # Afterwards we recalculate our number of training epochs
- args.num_train_epochs = math.ceil(
- args.max_train_steps / num_update_steps_per_epoch)
-
- # We need to initialize the trackers we use, and also store our configuration.
- # The trackers initializes automatically on the main process.
- if accelerator.is_main_process:
- accelerator.init_trackers("StableAnimator", config=vars(args))
-
- # Train!
- total_batch_size = args.per_gpu_batch_size * \
- accelerator.num_processes * args.gradient_accumulation_steps
-
- logger.info("***** Running training *****")
- logger.info(f" Num examples = {len(train_dataset)}")
- logger.info(f" Num Epochs = {args.num_train_epochs}")
- logger.info(
- f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
- logger.info(
- f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
- logger.info(
- f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
- logger.info(f" Total optimization steps = {args.max_train_steps}")
- global_step = 0
- first_epoch = 0
-
- def encode_image(pixel_values):
- pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
- pixel_values = (pixel_values + 1.0) / 2.0
-
- pixel_values = pixel_values.to(torch.float32)
- # Normalize the image with for CLIP input
- pixel_values = feature_extractor(
- images=pixel_values,
- do_normalize=True,
- do_center_crop=False,
- do_resize=False,
- do_rescale=False,
- return_tensors="pt",
- ).pixel_values
-
- pixel_values = pixel_values.to(
- device=accelerator.device, dtype=image_encoder.dtype)
- image_embeddings = image_encoder(pixel_values).image_embeds
- image_embeddings = image_embeddings.unsqueeze(1)
- return image_embeddings
-
- def _get_add_time_ids(
- fps,
- motion_bucket_id,
- noise_aug_strength,
- dtype,
- batch_size,
- unet=None,
- device=None
- ):
- add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
-
- add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=device)
- add_time_ids = add_time_ids.repeat(batch_size, 1)
- return add_time_ids
-
- # Potentially load in the weights and states from a previous save
- if args.resume_from_checkpoint:
- if args.resume_from_checkpoint != "latest":
- path = os.path.basename(args.resume_from_checkpoint)
- else:
- # Get the most recent checkpoint
- dirs = os.listdir(args.output_dir)
- dirs = [d for d in dirs if d.startswith("checkpoint")]
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
- path = dirs[-1] if len(dirs) > 0 else None
-
- if path is None:
- accelerator.print(
- f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
- )
- args.resume_from_checkpoint = None
- else:
- accelerator.print(f"Resuming from checkpoint {path}")
- accelerator.load_state(os.path.join(args.output_dir, path))
- global_step = int(path.split("-")[1])
-
- resume_global_step = global_step * args.gradient_accumulation_steps
- first_epoch = global_step // num_update_steps_per_epoch
- resume_step = resume_global_step % (
- num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
- # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(global_step, args.max_train_steps),
- disable=not accelerator.is_local_main_process)
- progress_bar.set_description("Steps")
-
- for epoch in range(first_epoch, args.num_train_epochs):
- pose_net.train()
- face_encoder.train()
- unet.train()
- train_loss = 0.0
-
- for step, batch in enumerate(train_dataloader):
- # Skip steps until we reach the resumed step
- if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
- if step % args.gradient_accumulation_steps == 0:
- progress_bar.update(1)
- continue
-
- with accelerator.accumulate(pose_net, face_encoder, unet):
- with accelerator.autocast():
- pixel_values = batch["pixel_values"].to(weight_dtype).to(
- accelerator.device, non_blocking=True
- )
- conditional_pixel_values = batch["reference_image"].to(weight_dtype).to(
- accelerator.device, non_blocking=True
- )
-
- latents = tensor_to_vae_latent(pixel_values, vae).to(dtype=weight_dtype)
-
- # Get the text embedding for conditioning.
- encoder_hidden_states = encode_image(conditional_pixel_values).to(dtype=weight_dtype)
- image_embed = encoder_hidden_states.clone()
-
- train_noise_aug = 0.02
- conditional_pixel_values = conditional_pixel_values + train_noise_aug * \
- randn_tensor(conditional_pixel_values.shape, generator=generator,
- device=conditional_pixel_values.device,
- dtype=conditional_pixel_values.dtype)
- conditional_latents = tensor_to_vae_latent(conditional_pixel_values, vae, scale=False)
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents)
- bsz = latents.shape[0]
- # Sample a random timestep for each image
- sigmas = rand_cosine_interpolated(shape=[bsz, ], image_d=image_d, noise_d_low=noise_d_low,
- noise_d_high=noise_d_high, sigma_data=sigma_data,
- min_value=min_value, max_value=max_value).to(latents.device,
- dtype=weight_dtype)
-
- # sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents)
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- sigmas_reshaped = sigmas.clone()
- while len(sigmas_reshaped.shape) < len(latents.shape):
- sigmas_reshaped = sigmas_reshaped.unsqueeze(-1)
-
- noisy_latents = latents + noise * sigmas_reshaped
-
- timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(latents.device,
- dtype=weight_dtype)
-
- inp_noisy_latents = noisy_latents / ((sigmas_reshaped ** 2 + 1) ** 0.5)
-
- added_time_ids = _get_add_time_ids(
- fps=6,
- motion_bucket_id=127.0,
- noise_aug_strength=train_noise_aug, # noise_aug_strength == 0.0
- dtype=encoder_hidden_states.dtype,
- batch_size=bsz,
- unet=unet,
- device=latents.device
- )
-
- added_time_ids = added_time_ids.to(latents.device)
-
- # Conditioning dropout to support classifier-free guidance during inference. For more details
- # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
- if args.conditioning_dropout_prob is not None:
- random_p = torch.rand(
- bsz, device=latents.device, generator=generator)
- # Sample masks for the edit prompts.
- prompt_mask = random_p < 2 * args.conditioning_dropout_prob
- prompt_mask = prompt_mask.reshape(bsz, 1, 1)
- # Final text conditioning.
- null_conditioning = torch.zeros_like(encoder_hidden_states)
- encoder_hidden_states = torch.where(
- prompt_mask, null_conditioning, encoder_hidden_states)
-
- # Sample masks for the original images.
- image_mask_dtype = conditional_latents.dtype
- image_mask = 1 - (
- (random_p >= args.conditioning_dropout_prob).to(
- image_mask_dtype)
- * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
- )
- image_mask = image_mask.reshape(bsz, 1, 1, 1)
- # Final image conditioning.
- conditional_latents = image_mask * conditional_latents
-
- # Concatenate the `conditional_latents` with the `noisy_latents`.
- conditional_latents = conditional_latents.unsqueeze(
- 1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
-
- pose_pixels = batch["pose_pixels"].to(
- dtype=weight_dtype, device=accelerator.device, non_blocking=True
- )
- faceid_embeds = batch["faceid_embeds"].to(
- dtype=weight_dtype, device=accelerator.device, non_blocking=True
- )
- pose_latents = pose_net(pose_pixels)
-
- # print("This is faceid_latents calculation")
- # print(faceid_embeds.size()) # [1, 512]
- # print(image_embed.size()) # [1, 1, 1024]
-
- faceid_latents = face_encoder(faceid_embeds, image_embed)
-
- inp_noisy_latents = torch.cat(
- [inp_noisy_latents, conditional_latents], dim=2)
- target = latents
-
- # print(f"the size of encoder_hidden_states: {encoder_hidden_states.size()}") # [1, 1, 1024]
- # print(f"the size of face latents: {faceid_latents.size()}") # [1, 4, 1024]
- encoder_hidden_states = torch.cat([encoder_hidden_states, faceid_latents], dim=1)
-
- encoder_hidden_states = encoder_hidden_states.to(latents.dtype)
- inp_noisy_latents = inp_noisy_latents.to(latents.dtype)
- pose_latents = pose_latents.to(latents.dtype)
-
- # Predict the noise residual
- model_pred = unet(
- inp_noisy_latents, timesteps, encoder_hidden_states,
- added_time_ids=added_time_ids,
- pose_latents=pose_latents,
- ).sample
-
- sigmas = sigmas_reshaped
- # Denoise the latents
- c_out = -sigmas / ((sigmas ** 2 + 1) ** 0.5)
- c_skip = 1 / (sigmas ** 2 + 1)
- denoised_latents = model_pred * c_out + c_skip * noisy_latents
- weighing = (1 + sigmas ** 2) * (sigmas ** -2.0)
-
- tgt_face_masks = batch["tgt_face_masks"].to(
- dtype=weight_dtype, device=accelerator.device, non_blocking=True
- )
- tgt_face_masks = rearrange(tgt_face_masks, "b f c h w -> (b f) c h w")
- tgt_face_masks = F.interpolate(tgt_face_masks, size=(target.size()[-2], target.size()[-1]),
- mode='nearest')
- tgt_face_masks = rearrange(tgt_face_masks, "(b f) c h w -> b f c h w", f=args.sample_n_frames)
-
- # MSE loss
- loss = torch.mean(
- (weighing.float() * (denoised_latents.float() -
- target.float()) ** 2 * (1 + tgt_face_masks)).reshape(target.shape[0], -1),
- dim=1,
- )
- loss = loss.mean()
-
- # Gather the losses across all processes for logging (if we use distributed training).
- avg_loss = accelerator.gather(
- loss.repeat(args.per_gpu_batch_size)).mean()
- train_loss += avg_loss.item() / args.gradient_accumulation_steps
-
- # Backpropagate
- accelerator.backward(loss)
- # if accelerator.sync_gradients:
- # accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
-
- with torch.cuda.device(latents.device):
- torch.cuda.empty_cache()
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- if args.use_ema:
- ema_unet.step(unet.parameters())
- progress_bar.update(1)
- global_step += 1
- accelerator.log({"train_loss": train_loss}, step=global_step)
- train_loss = 0.0
-
- # save checkpoints!
- # if global_step % args.checkpointing_steps == 0 and (accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED):
- if global_step % args.checkpointing_steps == 0 and accelerator.is_main_process:
-
- # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
- if args.checkpoints_total_limit is not None and accelerator.is_main_process:
- checkpoints = os.listdir(args.output_dir)
- checkpoints = [
- d for d in checkpoints if d.startswith("checkpoint")]
- checkpoints = sorted(
- checkpoints, key=lambda x: int(x.split("-")[1]))
-
- # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
- if len(checkpoints) >= args.checkpoints_total_limit:
- num_to_remove = len(
- checkpoints) - args.checkpoints_total_limit + 1
- removing_checkpoints = checkpoints[0:num_to_remove]
-
- logger.info(
- f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
- )
- logger.info(
- f"removing checkpoints: {', '.join(removing_checkpoints)}")
-
- for removing_checkpoint in removing_checkpoints:
- removing_checkpoint = os.path.join(
- args.output_dir, removing_checkpoint)
- shutil.rmtree(removing_checkpoint)
-
- save_path = os.path.join(
- args.output_dir, f"checkpoint-{global_step}")
- accelerator.save_state(save_path)
- unwrap_unet = accelerator.unwrap_model(unet)
- unwrap_pose_net = accelerator.unwrap_model(pose_net)
- unwrap_face_encoder = accelerator.unwrap_model(face_encoder)
- unwrap_unet_state_dict = unwrap_unet.state_dict()
- torch.save(unwrap_unet_state_dict,
- os.path.join(args.output_dir, f"checkpoint-{global_step}", f"unet-{global_step}.pth"))
- unwrap_pose_net_state_dict = unwrap_pose_net.state_dict()
- torch.save(unwrap_pose_net_state_dict, os.path.join(args.output_dir, f"checkpoint-{global_step}",
- f"pose_net-{global_step}.pth"))
- unwrap_face_encoder_state_dict = unwrap_face_encoder.state_dict()
- torch.save(unwrap_face_encoder_state_dict,
- os.path.join(args.output_dir, f"checkpoint-{global_step}",
- f"face_encoder-{global_step}.pth"))
- logger.info(f"Saved state to {save_path}")
-
- if accelerator.is_main_process:
- # sample images!
- if global_step % args.validation_steps == 0:
- logger.info(
- f"Running validation... \n Generating {args.num_validation_images} videos."
- )
- # create pipeline
- if args.use_ema:
- # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
- ema_unet.store(unet.parameters())
- ema_unet.copy_to(unet.parameters())
-
- log_validation(
- vae=vae,
- image_encoder=image_encoder,
- unet=unet,
- pose_net=pose_net,
- face_encoder=face_encoder,
- app=face_model.app,
- face_helper=face_model.face_helper,
- handler_ante=face_model.handler_ante,
- scheduler=noise_scheduler,
- accelerator=accelerator,
- feature_extractor=feature_extractor,
- width=512,
- height=512,
- torch_dtype=weight_dtype,
- validation_image_folder=args.validation_image_folder,
- validation_image=args.validation_image,
- validation_control_folder=args.validation_control_folder,
- output_dir=args.output_dir,
- generator=generator,
- global_step=global_step,
- num_validation_cases=1,
- )
-
- if args.use_ema:
- # Switch back to the original UNet parameters.
- ema_unet.restore(unet.parameters())
-
- with torch.cuda.device(latents.device):
- torch.cuda.empty_cache()
-
- logs = {"step_loss": loss.detach().item(
- ), "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
-
- if global_step >= args.max_train_steps:
- break
-
- # save checkpoints!
- # if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
- if accelerator.is_main_process:
- save_path = os.path.join(
- args.output_dir, f"checkpoint-last")
- accelerator.save_state(save_path)
- logger.info(f"Saved state to {save_path}")
-
-
-def log_validation(
- vae,
- image_encoder,
- unet,
- pose_net,
- face_encoder,
- app,
- face_helper,
- handler_ante,
- scheduler,
- accelerator,
- feature_extractor,
- width,
- height,
- torch_dtype,
- validation_image_folder,
- validation_image,
- validation_control_folder,
- output_dir,
- generator,
- global_step,
- num_validation_cases=1,
-):
- logger.info("Running validation... ")
- validation_unet = accelerator.unwrap_model(unet)
- validation_image_encoder = accelerator.unwrap_model(image_encoder)
- validation_vae = accelerator.unwrap_model(vae)
- validation_pose_net = accelerator.unwrap_model(pose_net)
- validation_face_encoder = accelerator.unwrap_model(face_encoder)
-
- pipeline = ValidationAnimationPipeline(
- vae=validation_vae,
- image_encoder=validation_image_encoder,
- unet=validation_unet,
- scheduler=scheduler,
- feature_extractor=feature_extractor,
- pose_net=validation_pose_net,
- face_encoder=validation_face_encoder,
- )
- pipeline = pipeline.to(accelerator.device)
- validation_images = load_images_from_folder(validation_image_folder)
- validation_image_path = validation_image
- if validation_image is None:
- validation_image = validation_images[0]
- else:
- validation_image = Image.open(validation_image).convert('RGB')
- validation_control_images = load_images_from_folder(validation_control_folder)
-
- val_save_dir = os.path.join(output_dir, "validation_images")
- if not os.path.exists(val_save_dir):
- os.makedirs(val_save_dir)
-
- with accelerator.autocast():
- for val_img_idx in range(num_validation_cases):
- # num_frames = args.num_frames
- num_frames = len(validation_control_images)
-
- face_helper.clean_all()
- validation_face = cv2.imread(validation_image_path)
- # validation_image_bgr = cv2.cvtColor(validation_face, cv2.COLOR_RGB2BGR)
- validation_image_face_info = app.get(validation_face)
- if len(validation_image_face_info) > 0:
- validation_image_face_info = sorted(validation_image_face_info,
- key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (
- x['bbox'][3] - x['bbox'][1]))[-1]
- validation_image_id_ante_embedding = validation_image_face_info['embedding']
- else:
- validation_image_id_ante_embedding = None
-
- if validation_image_id_ante_embedding is None:
- face_helper.read_image(validation_face)
- face_helper.get_face_landmarks_5(only_center_face=True)
- face_helper.align_warp_face()
-
- if len(face_helper.cropped_faces) == 0:
- validation_image_id_ante_embedding = np.zeros((512,))
- else:
- validation_image_align_face = face_helper.cropped_faces[0]
- print('fail to detect face using insightface, extract embedding on align face')
- validation_image_id_ante_embedding = handler_ante.get_feat(validation_image_align_face)
-
- video_frames = pipeline(
- image=validation_image,
- image_pose=validation_control_images,
- height=height,
- width=width,
- num_frames=num_frames,
- tile_size=num_frames,
- tile_overlap=4,
- decode_chunk_size=4,
- motion_bucket_id=127.,
- fps=7,
- min_guidance_scale=3,
- max_guidance_scale=3,
- noise_aug_strength=0.02,
- num_inference_steps=25,
- generator=generator,
- output_type="pil",
- validation_image_id_ante_embedding=validation_image_id_ante_embedding,
- ).frames[0]
- # save_combined_frames(video_frames, validation_images, validation_control_images, val_save_dir)
-
- out_file = os.path.join(
- val_save_dir,
- f"step_{global_step}_val_img_{val_img_idx}.mp4",
- )
- # print(video_frames.size()) # [16, 3, 512, 512]
- for i in range(num_frames):
- img = video_frames[i]
- video_frames[i] = np.array(img)
- export_to_gif(video_frames, out_file, 8)
-
- del pipeline
- torch.cuda.empty_cache()
-
-
-if __name__ == "__main__":
- main()
+from torch.utils.data._utils.collate import default_collate
+
+def safe_collate(batch):
+ good = []
+ for s in batch:
+ if s is None:
+ continue
+ good.append(s)
+ if len(good) == 0:
+ return None
+ return default_collate(good)
+
+
+import argparse
+import random
+import logging
+import math
+import os
+
+import cv2
+import shutil
+from pathlib import Path
+from urllib.parse import urlparse
+import numpy as np
+import PIL
+from PIL import Image, ImageDraw
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from diffusers.models.attention_processor import XFormersAttnProcessor
+from animation.dataset.animation_new_dataset import LargeScaleMusicVideos
+from animation.modules.attention_processor import AnimationAttnProcessor
+from animation.modules.attention_processor_normalized import AnimationIDAttnNormalizedProcessor
+from animation.modules.face_model import FaceModel
+from animation.modules.id_encoder import FusionFaceId
+from animation.modules.pose_net import PoseNet
+from animation.modules.unet import UNetSpatioTemporalConditionModel
+from animation.modules.music_encoder import MusicEncoder
+
+import transformers
+from accelerate import Accelerator, DistributedType
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from tqdm.auto import tqdm
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from einops import rearrange
+
+import datetime
+import diffusers
+from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
+from diffusers.utils.import_utils import is_xformers_available
+import warnings
+import torch.nn as nn
+from diffusers.utils.torch_utils import randn_tensor
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+# i should make a utility function file
+def validate_and_convert_image(image, target_size=(256, 256)):
+ if image is None:
+ print("Encountered a None image")
+ return None
+
+ if isinstance(image, torch.Tensor):
+ # Convert PyTorch tensor to PIL Image
+ if image.ndim == 3 and image.shape[0] in [1, 3]: # Check for CxHxW format
+ if image.shape[0] == 1: # Convert single-channel grayscale to RGB
+ image = image.repeat(3, 1, 1)
+ image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
+ image = Image.fromarray(image)
+ else:
+ print(f"Invalid image tensor shape: {image.shape}")
+ return None
+ elif isinstance(image, Image.Image):
+ # Resize PIL Image
+ image = image.resize(target_size)
+ else:
+ print("Image is not a PIL Image or a PyTorch tensor")
+ return None
+
+ return image
+
+
+def create_image_grid(images, rows, cols, target_size=(256, 256)):
+ valid_images = [validate_and_convert_image(img, target_size) for img in images]
+ valid_images = [img for img in valid_images if img is not None]
+
+ if not valid_images:
+ print("No valid images to create a grid")
+ return None
+
+ w, h = target_size
+ grid = Image.new('RGB', size=(cols * w, rows * h))
+
+ for i, image in enumerate(valid_images):
+ grid.paste(image, box=((i % cols) * w, (i // cols) * h))
+
+ return grid
+
+
+def save_combined_frames(batch_output, validation_images, validation_control_images, output_folder):
+ # Flatten batch_output, which is a list of lists of PIL Images
+ flattened_batch_output = [img for sublist in batch_output for img in sublist]
+
+ # Combine frames into a list without converting (since they are already PIL Images)
+ combined_frames = validation_images + validation_control_images + flattened_batch_output
+
+ # Calculate rows and columns for the grid
+ num_images = len(combined_frames)
+ cols = 3 # adjust number of columns as needed
+ rows = (num_images + cols - 1) // cols
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+
+ filename = f"combined_frames_{timestamp}.png"
+ # Create and save the grid image
+ grid = create_image_grid(combined_frames, rows, cols)
+ output_folder = os.path.join(output_folder, "validation_images")
+ os.makedirs(output_folder, exist_ok=True)
+
+ # Now define the full path for the file
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ filename = f"combined_frames_{timestamp}.png"
+ output_loc = os.path.join(output_folder, filename)
+
+ if grid is not None:
+ grid.save(output_loc)
+ else:
+ print("Failed to create image grid")
+
+
+# def load_images_from_folder(folder):
+# images = []
+# valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed
+#
+# # Function to extract frame number from the filename
+# def frame_number(filename):
+# # First, try the pattern 'frame_x_7fps'
+# new_pattern_match = re.search(r'frame_(\d+)_7fps', filename)
+# if new_pattern_match:
+# return int(new_pattern_match.group(1))
+# # If the new pattern is not found, use the original digit extraction method
+# matches = re.findall(r'\d+', filename)
+# if matches:
+# if matches[-1] == '0000' and len(matches) > 1:
+# return int(matches[-2]) # Return the second-to-last sequence if the last is '0000'
+# return int(matches[-1]) # Otherwise, return the last sequence
+# return float('inf') # Return 'inf'
+#
+# # Sorting files based on frame number
+# sorted_files = sorted(os.listdir(folder), key=frame_number)
+#
+# # Load images in sorted order
+# for filename in sorted_files:
+# ext = os.path.splitext(filename)[1].lower()
+# if ext in valid_extensions:
+# img = Image.open(os.path.join(folder, filename)).convert('RGB')
+# images.append(img)
+#
+# return images
+
+def load_images_from_folder(folder):
+ images = []
+
+ files = os.listdir(folder)
+ png_files = [f for f in files if f.endswith('.png')]
+ png_files.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
+ for filename in png_files:
+ img = Image.open(os.path.join(folder, filename)).convert('RGB')
+ images.append(img)
+
+ return images
+
+
+# copy from https://github.com/crowsonkb/k-diffusion.git
+def stratified_uniform(shape, group=0, groups=1, dtype=None, device=None):
+ """Draws stratified samples from a uniform distribution."""
+ if groups <= 0:
+ raise ValueError(f"groups must be positive, got {groups}")
+ if group < 0 or group >= groups:
+ raise ValueError(f"group must be in [0, {groups})")
+ n = shape[-1] * groups
+ offsets = torch.arange(group, n, groups, dtype=dtype, device=device)
+ u = torch.rand(shape, dtype=dtype, device=device)
+ return (offsets + u) / n
+
+
+def rand_cosine_interpolated(shape, image_d, noise_d_low, noise_d_high, sigma_data=1., min_value=1e-3, max_value=1e3,
+ device='cpu', dtype=torch.float32):
+ """Draws samples from an interpolated cosine timestep distribution (from simple diffusion)."""
+
+ def logsnr_schedule_cosine(t, logsnr_min, logsnr_max):
+ t_min = math.atan(math.exp(-0.5 * logsnr_max))
+ t_max = math.atan(math.exp(-0.5 * logsnr_min))
+ return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))
+
+ def logsnr_schedule_cosine_shifted(t, image_d, noise_d, logsnr_min, logsnr_max):
+ shift = 2 * math.log(noise_d / image_d)
+ return logsnr_schedule_cosine(t, logsnr_min - shift, logsnr_max - shift) + shift
+
+ def logsnr_schedule_cosine_interpolated(t, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max):
+ logsnr_low = logsnr_schedule_cosine_shifted(
+ t, image_d, noise_d_low, logsnr_min, logsnr_max)
+ logsnr_high = logsnr_schedule_cosine_shifted(
+ t, image_d, noise_d_high, logsnr_min, logsnr_max)
+ return torch.lerp(logsnr_low, logsnr_high, t)
+
+ logsnr_min = -2 * math.log(min_value / sigma_data)
+ logsnr_max = -2 * math.log(max_value / sigma_data)
+ u = stratified_uniform(
+ shape, group=0, groups=1, dtype=dtype, device=device
+ )
+ logsnr = logsnr_schedule_cosine_interpolated(
+ u, image_d, noise_d_low, noise_d_high, logsnr_min, logsnr_max)
+ return torch.exp(-logsnr / 2) * sigma_data
+
+
+def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
+ """Draws samples from an lognormal distribution."""
+ u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
+ return torch.distributions.Normal(loc, scale).icdf(u).exp()
+
+
+min_value = 0.002
+max_value = 700
+image_d = 64
+noise_d_low = 32
+noise_d_high = 64
+sigma_data = 0.5
+
+
+def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
+ h, w = input.shape[-2:]
+ factors = (h / size[0], w / size[1])
+
+ # First, we have to determine sigma
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+ sigmas = (
+ max((factors[0] - 1.0) / 2.0, 0.001),
+ max((factors[1] - 1.0) / 2.0, 0.001),
+ )
+
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+ # Make sure it is odd
+ if (ks[0] % 2) == 0:
+ ks = ks[0] + 1, ks[1]
+
+ if (ks[1] % 2) == 0:
+ ks = ks[0], ks[1] + 1
+
+ input = _gaussian_blur2d(input, ks, sigmas)
+
+ output = torch.nn.functional.interpolate(
+ input, size=size, mode=interpolation, align_corners=align_corners)
+ return output
+
+
+def _compute_padding(kernel_size):
+ """Compute padding tuple."""
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+ if len(kernel_size) < 2:
+ raise AssertionError(kernel_size)
+ computed = [k - 1 for k in kernel_size]
+
+ # for even kernels we need to do asymmetric padding :(
+ out_padding = 2 * len(kernel_size) * [0]
+
+ for i in range(len(kernel_size)):
+ computed_tmp = computed[-(i + 1)]
+
+ pad_front = computed_tmp // 2
+ pad_rear = computed_tmp - pad_front
+
+ out_padding[2 * i + 0] = pad_front
+ out_padding[2 * i + 1] = pad_rear
+
+ return out_padding
+
+
+def _filter2d(input, kernel):
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel = kernel[:, None, ...].to(
+ device=input.device, dtype=input.dtype)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+ height, width = tmp_kernel.shape[-2:]
+
+ padding_shape: list[int] = _compute_padding([height, width])
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+ # kernel and input tensor reshape to align element-wise or batch-wise params
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+ # convolve the tensor with the kernel.
+ output = torch.nn.functional.conv2d(
+ input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ out = output.view(b, c, h, w)
+ return out
+
+
+def _gaussian(window_size: int, sigma):
+ if isinstance(sigma, float):
+ sigma = torch.tensor([[sigma]])
+
+ batch_size = sigma.shape[0]
+
+ x = (torch.arange(window_size, device=sigma.device,
+ dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+ if window_size % 2 == 0:
+ x = x + 0.5
+
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+ return gauss / gauss.sum(-1, keepdim=True)
+
+
+def _gaussian_blur2d(input, kernel_size, sigma):
+ if isinstance(sigma, tuple):
+ sigma = torch.tensor([sigma], dtype=input.dtype)
+ else:
+ sigma = sigma.to(dtype=input.dtype)
+
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
+ bs = sigma.shape[0]
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
+ out_x = _filter2d(input, kernel_x[..., None, :])
+ out = _filter2d(out_x, kernel_y[..., None])
+
+ return out
+
+
+def export_to_video(video_frames, output_video_path, fps):
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ h, w, _ = video_frames[0].shape
+ video_writer = cv2.VideoWriter(
+ output_video_path, fourcc, fps=fps, frameSize=(w, h))
+ for i in range(len(video_frames)):
+ img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
+ video_writer.write(img)
+
+
+def export_to_gif(frames, output_gif_path, fps):
+ """
+ Export a list of frames to a GIF.
+
+ Args:
+ - frames (list): List of frames (as numpy arrays or PIL Image objects).
+ - output_gif_path (str): Path to save the output GIF.
+ - duration_ms (int): Duration of each frame in milliseconds.
+
+ """
+ # Convert numpy arrays to PIL Images if needed
+ pil_frames = [Image.fromarray(frame) if isinstance(
+ frame, np.ndarray) else frame for frame in frames]
+
+ pil_frames[0].save(output_gif_path.replace('.mp4', '.gif'),
+ format='GIF',
+ append_images=pil_frames[1:],
+ save_all=True,
+ duration=125,
+ loop=0)
+
+
+def tensor_to_vae_latent(t, vae, scale=True):
+ t = t.to(vae.dtype)
+ if len(t.shape) == 5:
+ video_length = t.shape[1]
+
+ t = rearrange(t, "b f c h w -> (b f) c h w")
+ latents = vae.encode(t).latent_dist.sample()
+ latents = rearrange(latents, "(b f) c h w -> b f c h w", f=video_length)
+ elif len(t.shape) == 4:
+ latents = vae.encode(t).latent_dist.sample()
+ if scale:
+ latents = latents * vae.config.scaling_factor
+ return latents
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Script to train Stable Diffusion XL for InstructPix2Pix."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+
+ parser.add_argument(
+ "--num_frames",
+ type=int,
+ default=14,
+ )
+ parser.add_argument(
+ "--dataset_type",
+ type=str,
+ default='ubc',
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=1,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=500,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the text/image prompt"
+ " multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="./outputs",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=None, help="A seed for reproducible training."
+ )
+ parser.add_argument(
+ "--per_gpu_batch_size",
+ type=int,
+ default=1,
+ help="Batch size (per device) for the training dataloader.",
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps",
+ type=int,
+ default=500,
+ help="Number of steps for the warmup in the lr scheduler.",
+ )
+ parser.add_argument(
+ "--conditioning_dropout_prob",
+ type=float,
+ default=0.1,
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes.",
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--use_ema", action="store_true", help="Whether to use EMA model."
+ )
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=8,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--adam_beta1",
+ type=float,
+ default=0.9,
+ help="The beta1 parameter for the Adam optimizer.",
+ )
+ parser.add_argument(
+ "--adam_beta2",
+ type=float,
+ default=0.999,
+ help="The beta2 parameter for the Adam optimizer.",
+ )
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
+ )
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer",
+ )
+ parser.add_argument(
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the model to the Hub.",
+ )
+ parser.add_argument(
+ "--hub_token",
+ type=str,
+ default=None,
+ help="The token to use to push to the Model Hub.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--music_encoder_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained music encoder (.pth)"
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ default=-1,
+ help="For distributed training: local_rank",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=1,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention",
+ action="store_true",
+ help="Whether or not to use xformers.",
+ )
+ parser.add_argument(
+ "--log_trainable_parameters",
+ action="store_true",
+ help="Whether to write the trainable parameters.",
+ )
+ parser.add_argument(
+ "--pretrain_unet",
+ type=str,
+ default=None,
+ help="use weight for unet block",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=128,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--csv_path",
+ type=str,
+ default=None,
+ help=(
+ "path to the dataset csv"
+ ),
+ )
+ parser.add_argument(
+ "--video_folder",
+ type=str,
+ default=None,
+ help=(
+ "path to the video folder"
+ ),
+ )
+ parser.add_argument(
+ "--condition_folder",
+ type=str,
+ default=None,
+ help=(
+ "path to the depth folder"
+ ),
+ )
+ parser.add_argument(
+ "--motion_folder",
+ type=str,
+ default=None,
+ help=(
+ "path to the depth folder"
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image_folder",
+ type=str,
+ default=None,
+ help=(
+ "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ help=(
+ "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_control_folder",
+ type=str,
+ default=None,
+ help=(
+ "the validation control image"
+ ),
+ )
+ parser.add_argument(
+ "--sample_n_frames",
+ type=int,
+ default=14,
+ help=(
+ "the sample_n_frames"
+ ),
+ )
+
+ parser.add_argument(
+ "--ref_augment",
+ action="store_true",
+ help=(
+ "use augment for the reference image"
+ ),
+ )
+ parser.add_argument(
+ "--train_stage",
+ type=int,
+ default=2,
+ help=(
+ "the training stage"
+ ),
+ )
+
+ parser.add_argument(
+ "--posenet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained posenet model",
+ )
+ parser.add_argument(
+ "--face_encoder_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained face encoder model",
+ )
+ parser.add_argument(
+ "--unet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained unet model",
+ )
+
+ parser.add_argument(
+ "--data_root_path",
+ type=str,
+ default=None,
+ help="Path to the data root path",
+ )
+ parser.add_argument(
+ "--data_path",
+ type=str,
+ default=None,
+ help="Path to the data path",
+ )
+
+ parser.add_argument(
+ "--finetune_mode",
+ type=bool,
+ default=False,
+ help="Enable or disable the finetune mode (True/False).",
+ )
+ parser.add_argument(
+ "--posenet_model_finetune_path",
+ type=str,
+ default=None,
+ help="Path to the pretrained posenet model",
+ )
+ parser.add_argument(
+ "--face_encoder_finetune_path",
+ type=str,
+ default=None,
+ help="Path to the pretrained face encoder",
+ )
+ parser.add_argument(
+ "--unet_model_finetune_path",
+ type=str,
+ default=None,
+ help="Path to the pretrained unet model",
+ )
+
+ parser.add_argument(
+ "--dataset_width",
+ type=int,
+ default=512,
+ help="video dataset width",
+ )
+ parser.add_argument(
+ "--dataset_height",
+ type=int,
+ default=512,
+ help="video dataset height",
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def download_image(url):
+ original_image = (
+ lambda image_url_or_path: load_image(image_url_or_path)
+ if urlparse(image_url_or_path).scheme
+ else PIL.Image.open(image_url_or_path).convert("RGB")
+ )(url)
+ return original_image
+
+
+# This is for training using deepspeed.
+# Since now the DeepSpeed only supports trainging with only one model
+# So we create a virtual wrapper to contail all the models
+
+class DeepSpeedWrapperModel(nn.Module):
+ def __init__(self, **kwargs):
+ super().__init__()
+ for name, value in kwargs.items():
+ assert isinstance(value, nn.Module)
+ self.register_module(name, value)
+
+
+def main():
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
+ warnings.filterwarnings('ignore', category=FutureWarning)
+ torch.multiprocessing.set_start_method('spawn')
+
+ args = parse_args()
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+ accelerator_project_config = ProjectConfiguration(
+ project_dir=args.output_dir, logging_dir=logging_dir)
+ # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ project_config=accelerator_project_config,
+ )
+
+ generator = torch.Generator(
+ device=accelerator.device).manual_seed(23123134)
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError(
+ "Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ print(args.pretrained_model_name_or_path)
+ feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path,
+ subfolder="feature_extractor", revision=args.revision)
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="image_encoder", revision=args.revision
+ )
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant="fp16")
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path if args.pretrain_unet is None else args.pretrain_unet,
+ subfolder="unet",
+ low_cpu_mem_usage=True,
+ variant="fp16"
+ )
+ pose_net = PoseNet(noise_latent_channels=unet.config.block_out_channels[0])
+ face_encoder = FusionFaceId(
+ cross_attention_dim=1024,
+ id_embeddings_dim=512,
+ clip_embeddings_dim=1024,
+ num_tokens=4, )
+ face_model = FaceModel()
+ music_encoder = MusicEncoder()
+ # init adapter modules
+ lora_rank = 128
+ attn_procs = {}
+ unet_svd = unet.state_dict()
+
+ for name in unet.attn_processors.keys():
+ if "transformer_blocks" in name and "temporal_transformer_blocks" not in name:
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ # print(f"This is AnimationAttnProcessor: {name}")
+ attn_procs[name] = AnimationAttnProcessor(hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim, rank=lora_rank)
+ else:
+ # print(f"This is AnimationIDAttnNormalizedProcessor: {name}")
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "id_to_k.weight": unet_svd[layer_name + ".to_k.weight"],
+ "id_to_v.weight": unet_svd[layer_name + ".to_v.weight"],
+ }
+ attn_procs[name] = AnimationIDAttnNormalizedProcessor(hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=lora_rank)
+ attn_procs[name].load_state_dict(weights, strict=False)
+ elif "temporal_transformer_blocks" in name:
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ attn_procs[name] = XFormersAttnProcessor()
+ else:
+ attn_procs[name] = XFormersAttnProcessor()
+ unet.set_attn_processor(attn_procs)
+
+ # triggering the finetune mode
+ if args.finetune_mode is True and args.posenet_model_finetune_path is not None and args.face_encoder_finetune_path is not None and args.unet_model_finetune_path is not None:
+ print("Loading existing posenet weights, face_encoder weights and unet weights.")
+ if args.posenet_model_finetune_path.endswith(".pth"):
+ pose_net_state_dict = torch.load(args.posenet_model_finetune_path, map_location="cpu")
+ pose_net.load_state_dict(pose_net_state_dict, strict=True)
+ else:
+ print("posenet weights loading fail")
+ print(1 / 0)
+ if args.face_encoder_finetune_path.endswith(".pth"):
+ face_encoder_state_dict = torch.load(args.face_encoder_finetune_path, map_location="cpu")
+ face_encoder.load_state_dict(face_encoder_state_dict, strict=True)
+ else:
+ print("face_encoder weights loading fail")
+ print(1 / 0)
+ if args.unet_model_finetune_path.endswith(".pth"):
+ unet_state_dict = torch.load(args.unet_model_finetune_path, map_location="cpu")
+ unet.load_state_dict(unet_state_dict, strict=True)
+ else:
+ print("unet weights loading fail")
+ print(1 / 0)
+
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+ # Freeze vae and image_encoder
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ pose_net.requires_grad_(False)
+ face_encoder.requires_grad_(False)
+
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ if args.use_ema:
+ ema_unet = EMAModel(unet.parameters(
+ ), model_cls=UNetSpatioTemporalConditionModel, model_config=unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError(
+ "xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps *
+ args.per_gpu_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ # if accelerator.distributed_type == DistributedType.DEEPSPEED:
+ # ds_wrapper = DeepSpeedWrapperModel(
+ # unet=unet,
+ # controlnext=controlnext
+ # )
+ # unet = ds_wrapper.unet
+ # controlnext = ds_wrapper.controlnext
+
+ pose_net.requires_grad_(True)
+ face_encoder.requires_grad_(True)
+
+ parameters_list = []
+
+ for para in music_encoder.parameters():
+ para.requires_grad = True
+ parameters_list.append({"params": para, "lr": args.learning_rate})
+
+ for name, para in face_encoder.named_parameters():
+ para.requires_grad = True
+ parameters_list.append({"params": para, "lr": args.learning_rate})
+
+ """
+ For more details, please refer to: https://github.com/dvlab-research/ControlNeXt/issues/14#issuecomment-2290450333
+ This is the selective parameters part.
+ As presented in our paper, we only select a small subset of parameters, which is fully adapted to the SD1.5 and SDXL backbones. By training fewer than 100 million parameters, we still achieve excellent performance. But this is is not suitable for the SD3 and SVD training. This is because, after SDXL, Stability faced significant legal risks due to the generation of highly realistic human images. After that, they stopped refining their models on human-related data, such as SVD and SD3, to avoid potential risks.
+ To achieve optimal performance, it's necessary to first continue training SVD and SD3 on human-related data to develop a robust backbone before fine-tuning. Of course, you can also combine the continual pretraining and finetuning. So you can find that we direct provide the full SVD parameters.
+ We have experimented with two approaches: 1.Directly training the model from scratch on human dancing data. 2. Continual training using a pre-trained human generation backbone, followed by fine-tuning a selective small subset of parameters. Interestingly, we observed no significant difference in performance between these two methods.
+ """
+
+ for name, para in unet.named_parameters():
+ if "attentions" in name:
+ para.requires_grad = True
+ parameters_list.append({"params": para})
+ else:
+ para.requires_grad = False
+
+ optimizer = optimizer_cls(
+ parameters_list,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # check para
+ if accelerator.is_main_process and args.log_trainable_parameters:
+ rec_txt1 = open('rec_para.txt', 'w')
+ rec_txt2 = open('rec_para_train.txt', 'w')
+ for name, para in unet.named_parameters():
+ if para.requires_grad is False:
+ rec_txt1.write(f'{name}\n')
+ else:
+ rec_txt2.write(f'{name}\n')
+ rec_txt1.close()
+ rec_txt2.close()
+ # DataLoaders creation:
+ args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
+
+ root_path = args.data_root_path
+ txt_path = args.data_path
+ train_dataset = LargeScaleMusicVideos(
+ root_path=root_path,
+ txt_path=txt_path,
+ width=args.dataset_width,
+ height=args.dataset_height,
+ n_sample_frames=args.sample_n_frames,
+ sample_frame_rate=4,
+ app=face_model.app,
+ handler_ante=face_model.handler_ante,
+ face_helper=face_model.face_helper
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.per_gpu_batch_size,
+ num_workers=args.num_workers,
+ shuffle=True,
+ collate_fn=safe_collate,
+ drop_last=True, # DDP์์ ๋ง์ง๋ง ๋ฐฐ์น ๋ถ์ผ์น ๋ฐฉ์ง
+
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ unet, music_encoder, face_encoder, optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
+ unet, music_encoder, face_encoder, optimizer, lr_scheduler, train_dataloader
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(
+ args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("StableAnimator", config=vars(args))
+
+ # Train!
+ total_batch_size = args.per_gpu_batch_size * \
+ accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(
+ f" Instantaneous batch size per device = {args.per_gpu_batch_size}")
+ logger.info(
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(
+ f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ def encode_image(pixel_values):
+ pixel_values = _resize_with_antialiasing(pixel_values, (224, 224))
+ pixel_values = (pixel_values + 1.0) / 2.0
+
+ pixel_values = pixel_values.to(torch.float32)
+ # Normalize the image with for CLIP input
+ pixel_values = feature_extractor(
+ images=pixel_values,
+ do_normalize=True,
+ do_center_crop=False,
+ do_resize=False,
+ do_rescale=False,
+ return_tensors="pt",
+ ).pixel_values
+
+ pixel_values = pixel_values.to(
+ device=accelerator.device, dtype=image_encoder.dtype)
+ image_embeddings = image_encoder(pixel_values).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+ return image_embeddings
+
+ def _get_add_time_ids(
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ dtype,
+ batch_size,
+ unet=None,
+ device=None
+ ):
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=device)
+ add_time_ids = add_time_ids.repeat(batch_size, 1)
+ return add_time_ids
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+
+ music_state = torch.load("/root/musicenc_final.pt", map_location="cpu")["model_state"]
+
+ # ์ด๋ฏธ prepare() ๋ ๋ชจ๋ธ์ด๋ผ๋ฉด unwrap ํ ๋ก๋
+ accelerator.unwrap_model(music_encoder).load_state_dict(music_state, strict=True)
+
+ # ํน์๋ฅผ ์ํด ๋๊ธฐํ
+ accelerator.wait_for_everyone()
+ args.resume_from_checkpoint = None
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ resume_global_step = global_step * args.gradient_accumulation_steps
+ first_epoch = global_step // num_update_steps_per_epoch
+ resume_step = resume_global_step % (
+ num_update_steps_per_epoch * args.gradient_accumulation_steps)
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(global_step, args.max_train_steps),
+ disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+
+
+ # ๋ณ๊ฒฝ๋ unet๊ณผ music_encoder ์ถ๊ฐํด์ train ์งํ
+ for epoch in range(first_epoch, args.num_train_epochs):
+
+ # pose_net.train()
+ face_encoder.train()
+ music_encoder.train()
+ unet.train()
+ train_loss = 0.0
+
+ for step, batch in enumerate(train_dataloader):
+ if batch is None:
+ if accelerator.is_local_main_process:
+ print(f"[WARN] Skip batch at step {step} (all samples invalid)")
+ continue
+
+ # Skip steps until we reach the resumed step
+ if step == 0:
+ print("="*50)
+ print("DEBUG: First Batch Information")
+ print("="*50)
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ print(f"{key}: shape={value.shape}, dtype={value.dtype}, "
+ f"min={value.min():.3f}, max={value.max():.3f}")
+ else:
+ print(f"{key}: {type(value)}")
+
+ # ์์
ํน์ง ํ์ธ
+ if "music_fea" in batch:
+ music = batch["music_fea"]
+ print(f"\nMusic features stats:")
+ print(f" Mean: {music.mean():.3f}, Std: {music.std():.3f}")
+ print(f" All zeros? {(music == 0).all()}")
+ print(f" All same value? {music.std() < 1e-6}")
+
+
+ # if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+ # if step % args.gradient_accumulation_steps == 0:
+ # progress_bar.update(1)
+ # continue
+
+ with accelerator.accumulate(face_encoder, music_encoder, unet): # pose_net
+ with accelerator.autocast():
+ pixel_values = batch["pixel_values"].to(weight_dtype).to(
+ accelerator.device, non_blocking=True
+ )
+ conditional_pixel_values = batch["reference_image"].to(weight_dtype).to(
+ accelerator.device, non_blocking=True
+ )
+
+ latents = tensor_to_vae_latent(pixel_values, vae).to(dtype=weight_dtype)
+
+ # Get the text embedding for conditioning.
+ encoder_hidden_states = encode_image(conditional_pixel_values).to(dtype=weight_dtype)
+ image_embed = encoder_hidden_states.clone()
+
+ train_noise_aug = 0.02
+ conditional_pixel_values = conditional_pixel_values + train_noise_aug * \
+ randn_tensor(conditional_pixel_values.shape, generator=generator,
+ device=conditional_pixel_values.device,
+ dtype=conditional_pixel_values.dtype)
+ conditional_latents = tensor_to_vae_latent(conditional_pixel_values, vae, scale=False)
+
+ # Samle noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ sigmas = rand_cosine_interpolated(shape=[bsz, ], image_d=image_d, noise_d_low=noise_d_low,
+ noise_d_high=noise_d_high, sigma_data=sigma_data,
+ min_value=min_value, max_value=max_value).to(latents.device,
+ dtype=weight_dtype)
+
+ # sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents)
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ sigmas_reshaped = sigmas.clone()
+ while len(sigmas_reshaped.shape) < len(latents.shape):
+ sigmas_reshaped = sigmas_reshaped.unsqueeze(-1)
+
+ noisy_latents = latents + noise * sigmas_reshaped
+
+ timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(latents.device,
+ dtype=weight_dtype)
+
+ inp_noisy_latents = noisy_latents / ((sigmas_reshaped ** 2 + 1) ** 0.5)
+
+ added_time_ids = _get_add_time_ids(
+ fps=6,
+ motion_bucket_id=127.0,
+ noise_aug_strength=train_noise_aug, # noise_aug_strength == 0.0
+ dtype=encoder_hidden_states.dtype,
+ batch_size=bsz,
+ unet=unet,
+ device=latents.device
+ )
+
+ added_time_ids = added_time_ids.to(latents.device)
+
+ # Conditioning dropout to support classifier-free guidance during inference. For more details
+ # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
+ if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(
+ bsz, device=latents.device, generator=generator)
+ # Sample masks for the edit prompts.
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ # Final text conditioning.
+ null_conditioning = torch.zeros_like(encoder_hidden_states)
+ encoder_hidden_states = torch.where(
+ prompt_mask, null_conditioning, encoder_hidden_states)
+
+ # Sample masks for the original images.
+ image_mask_dtype = conditional_latents.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(
+ image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ # Final image conditioning.
+ conditional_latents = image_mask * conditional_latents
+
+ # Concatenate the `conditional_latents` with the `noisy_latents`.
+ conditional_latents = conditional_latents.unsqueeze(
+ 1).repeat(1, noisy_latents.shape[1], 1, 1, 1)
+
+ #pose_pixels = batch["pose_pixels"].to(
+ # dtype=weight_dtype, device=accelerator.device, non_blocking=True
+ #)
+
+ faceid_embeds = batch.get("faceid_embeds")
+ if faceid_embeds is not None and not isinstance(faceid_embeds, type(None)):
+ faceid_embeds = faceid_embeds.to(
+ dtype=weight_dtype, device=accelerator.device, non_blocking=True
+ )
+ else:
+ faceid_embeds = None
+
+ tgt_face_masks = batch["tgt_face_masks"].to(
+ dtype=weight_dtype, device=accelerator.device, non_blocking=True
+ )
+
+
+ music_embeds = batch["music_fea"].to(
+ dtype=weight_dtype, device=accelerator.device, non_blocking=True
+ )
+ music_latents = music_encoder(music_embeds)
+ # pose_latents = pose_net(pose_pixels)
+
+ # print("This is faceid_latents calculation")
+ # print(faceid_embeds.size()) # [1, 512]
+ # print(image_embed.size()) # [1, 1, 1024]
+
+ faceid_latents = face_encoder(faceid_embeds, image_embed)
+
+ inp_noisy_latents = torch.cat(
+ [inp_noisy_latents, conditional_latents], dim=2)
+ target = latents
+
+ # print(f"the size of encoder_hidden_states: {encoder_hidden_states.size()}") # [1, 1, 1024]
+ # print(f"the size of face latents: {faceid_latents.size()}") # [1, 4, 1024]
+ # print(f"the size of music latents: {music_latents.size()}") # [1, 4, 1024]
+ encoder_hidden_states = torch.cat([
+ encoder_hidden_states, # (B, 1, 1024) reference
+ faceid_latents # (B, 4, 1024) face (if exists)
+ ], dim=1)
+ encoder_hidden_states = encoder_hidden_states.to(latents.dtype)
+ inp_noisy_latents = inp_noisy_latents.to(latents.dtype)
+ #pose_latents = pose_latents.to(latents.dtype)
+
+ # Predict the noise residual
+ model_pred = unet(
+ inp_noisy_latents,
+ timesteps,
+ encoder_hidden_states, # Music ํฌํจ๋จ
+ added_time_ids=added_time_ids,
+ pose_latents=music_latents,
+ ).sample
+
+ sigmas = sigmas_reshaped
+ # Denoise the latents
+ c_out = -sigmas / ((sigmas ** 2 + 1) ** 0.5)
+ c_skip = 1 / (sigmas ** 2 + 1)
+ denoised_latents = model_pred * c_out + c_skip * noisy_latents
+ weighing = (1 + sigmas ** 2) * (sigmas ** -2.0)
+
+ if tgt_face_masks.shape[-2:] != target.shape[-2:]:
+ b, f = tgt_face_masks.shape[0], tgt_face_masks.shape[1]
+ tgt_face_masks_flat = tgt_face_masks.reshape(b * f, *tgt_face_masks.shape[2:])
+ tgt_face_masks_flat = F.interpolate(
+ tgt_face_masks_flat,
+ size=target.shape[-2:],
+ mode='bilinear',
+ align_corners=False
+ )
+ tgt_face_masks = tgt_face_masks_flat.reshape(b, f, *tgt_face_masks_flat.shape[1:])
+
+ # MSE loss
+ loss = torch.mean(
+ (weighing.float() * (denoised_latents.float() -
+ target.float()) ** 2 * (1 + tgt_face_masks)).reshape(target.shape[0], -1),
+ dim=1,
+ )
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(
+ loss.repeat(args.per_gpu_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ # if accelerator.sync_gradients:
+ # accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ with torch.cuda.device(latents.device):
+ torch.cuda.empty_cache()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ # save checkpoints!
+ # if global_step % args.checkpointing_steps == 0 and (accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED):
+ if global_step % args.checkpointing_steps == 0 and accelerator.is_main_process:
+
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None and accelerator.is_main_process:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [
+ d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(
+ checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(
+ checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(
+ f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(
+ args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(
+ args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ unwrap_unet = accelerator.unwrap_model(unet)
+
+ unwrap_face_encoder = accelerator.unwrap_model(face_encoder)
+ unwrap_unet_state_dict = unwrap_unet.state_dict()
+ unwrap_music_encoder = accelerator.unwrap_model(music_encoder)
+ torch.save(unwrap_unet_state_dict,
+ os.path.join(args.output_dir, f"checkpoint-{global_step}", f"unet-{global_step}.pth"))
+
+ unwrap_face_encoder_state_dict = unwrap_face_encoder.state_dict()
+ torch.save(unwrap_face_encoder_state_dict,
+ os.path.join(args.output_dir, f"checkpoint-{global_step}",
+ f"face_encoder-{global_step}.pth"))
+ torch.save(unwrap_music_encoder.state_dict(),
+ os.path.join(args.output_dir, f"checkpoint-{global_step}",
+ f"music_encoder-{global_step}.pth"))
+ logger.info(f"Saved state to {save_path}")
+
+ if accelerator.is_main_process:
+ # sample images!
+ if global_step % args.validation_steps == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} videos."
+ )
+ # create pipeline
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ with torch.cuda.device(latents.device):
+ torch.cuda.empty_cache()
+
+ if step == 0 and epoch == 0:
+ print("DEBUG: Batch Data Shapes")
+ for key, value in batch.items():
+ print(f"{key}: {value.shape}")
+ if step % 10 == 0:
+ print(f"Step {step}, Loss: {loss.item():.4f}")
+
+ logs = {"step_loss": loss.detach().item(
+ ), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # save checkpoints!
+ # if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
+ if accelerator.is_main_process:
+ save_path = os.path.join(
+ args.output_dir, f"checkpoint-last")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+
+if __name__ == "__main__":
+ main()