diff --git a/README.md b/README.md index 09b6353d..7dc3961e 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,43 @@ Here we demonstrate several best results we found in our experiments.
Model:FilmVelvia
+### Longer generations +You can also generate longer animations by using overlapping sliding windows. +``` +python -m scripts.animate --config configs/prompts/{your_config}.yaml --L 64 --context_length 16 +``` +##### Sliding window related parameters: + +```L``` - the length of the generated animation. + +```context_length``` - the length of the sliding window (limited by motion modules capacity), default to ```L```. + +```context_overlap``` - how much neighbouring contexts overlap. By default ```context_length``` / 2 + +```context_stride``` - (2^```context_stride```) is a max stride between 2 neighbour frames. By default 0 + +##### Extended this way gallery examples + +![]() |
+ ![]() |
+ ![]() |
+ ![]() |
+
Model:ToonYou
+ +![]() |
+ ![]() |
+ ![]() |
+ ![]() |
+
Model:Realistic Vision V2.0
+ #### Community Cases Here are some samples contributed by the community artists. Create a Pull Request if you would like to show your results here😚. diff --git a/__assets__/animations/model_01_4x/01.gif b/__assets__/animations/model_01_4x/01.gif new file mode 100644 index 00000000..c1bfb487 Binary files /dev/null and b/__assets__/animations/model_01_4x/01.gif differ diff --git a/__assets__/animations/model_01_4x/02.gif b/__assets__/animations/model_01_4x/02.gif new file mode 100644 index 00000000..c891d0a5 Binary files /dev/null and b/__assets__/animations/model_01_4x/02.gif differ diff --git a/__assets__/animations/model_01_4x/03.gif b/__assets__/animations/model_01_4x/03.gif new file mode 100644 index 00000000..87008443 Binary files /dev/null and b/__assets__/animations/model_01_4x/03.gif differ diff --git a/__assets__/animations/model_01_4x/04.gif b/__assets__/animations/model_01_4x/04.gif new file mode 100644 index 00000000..4e03ca5e Binary files /dev/null and b/__assets__/animations/model_01_4x/04.gif differ diff --git a/__assets__/animations/model_03_4x/01.gif b/__assets__/animations/model_03_4x/01.gif new file mode 100644 index 00000000..617b5433 Binary files /dev/null and b/__assets__/animations/model_03_4x/01.gif differ diff --git a/__assets__/animations/model_03_4x/02.gif b/__assets__/animations/model_03_4x/02.gif new file mode 100644 index 00000000..e0a61af8 Binary files /dev/null and b/__assets__/animations/model_03_4x/02.gif differ diff --git a/__assets__/animations/model_03_4x/03.gif b/__assets__/animations/model_03_4x/03.gif new file mode 100644 index 00000000..08055edc Binary files /dev/null and b/__assets__/animations/model_03_4x/03.gif differ diff --git a/__assets__/animations/model_03_4x/04.gif b/__assets__/animations/model_03_4x/04.gif new file mode 100644 index 00000000..b590c66a Binary files /dev/null and b/__assets__/animations/model_03_4x/04.gif differ diff --git a/animatediff/pipelines/pipeline_animation.py b/animatediff/pipelines/pipeline_animation.py index 58f22d16..36263d6a 100644 --- a/animatediff/pipelines/pipeline_animation.py +++ b/animatediff/pipelines/pipeline_animation.py @@ -1,11 +1,13 @@ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py import inspect +import os from typing import Callable, List, Optional, Union from dataclasses import dataclass import numpy as np import torch +from torch import nn from tqdm import tqdm from diffusers.utils import is_accelerate_available @@ -29,6 +31,8 @@ from ..models.unet import UNet3DConditionModel +from ..utils import overlap_policy +from ..utils.path import get_absolute_path logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,6 +59,7 @@ def __init__( EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ], + scan_inversions: bool = True, ): super().__init__() @@ -114,6 +119,36 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.embeddings_dir = get_absolute_path('models', 'embeddings') + self.embeddings_dict = {} + self.default_tokens = len(self.tokenizer) + self.scan_inversions = scan_inversions + + def update_embeddings(self): + if not self.scan_inversions: + return + names = [p for p in os.listdir(self.embeddings_dir) if p.endswith('.pt')] + weight = self.text_encoder.text_model.embeddings.token_embedding.weight + added_embeddings = [] + for name in names: + embedding_path = os.path.join(self.embeddings_dir, name) + embedding = torch.load(embedding_path) + key = os.path.splitext(name)[0] + if key in self.tokenizer.encoder: + idx = self.tokenizer.encoder[key] + else: + idx = len(self.tokenizer) + self.tokenizer.add_tokens([key]) + embedding = embedding['string_to_param']['*'] + if idx not in self.embeddings_dict: + added_embeddings.append(name) + self.embeddings_dict[idx] = torch.arange(weight.shape[0], weight.shape[0] + embedding.shape[0]) + weight = torch.cat([weight, embedding.to(weight.device, weight.dtype)], dim=0) + self.tokenizer.add_tokens([key]) + if added_embeddings: + self.text_encoder.text_model.embeddings.token_embedding = nn.Embedding( + weight.shape[0], weight.shape[1], _weight=weight) + logger.info(f'Added {len(added_embeddings)} embeddings: {added_embeddings}') def enable_vae_slicing(self): self.vae.enable_slicing() @@ -147,9 +182,32 @@ def _execution_device(self): return torch.device(module._hf_hook.execution_device) return self.device + def insert_inversions(self, ids, attention_mask): + larger = ids >= self.default_tokens + for idx in reversed(torch.where(larger)[1]): + ids = torch.cat([ + ids[:, :idx], + self.embeddings_dict[ids[:, idx].item()].unsqueeze(0), + ids[:, idx + 1:], + ], 1) + if attention_mask is not None: + attention_mask = torch.cat([ + attention_mask[:, :idx], + torch.ones(1, 1, dtype=attention_mask.dtype, device=attention_mask.device), + attention_mask[:, idx + 1:], + ], 1) + if ids.shape[1] > self.tokenizer.model_max_length: + logger.warning(f"After inserting inversions, the sequence length is larger than the max length. Cutting off" + f" {ids.shape[1] - self.tokenizer.model_max_length} tokens.") + ids = torch.cat([ids[:, :self.tokenizer.model_max_length - 1], ids[:, -1:]], 1) + if attention_mask is not None: + attention_mask = attention_mask[:, :self.tokenizer.model_max_length] + return ids, attention_mask + def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): batch_size = len(prompt) if isinstance(prompt, list) else 1 + self.update_embeddings() text_inputs = self.tokenizer( prompt, padding="max_length", @@ -172,6 +230,7 @@ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_fr else: attention_mask = None + text_input_ids, attention_mask = self.insert_inversions(text_input_ids, attention_mask) text_embeddings = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, @@ -218,8 +277,10 @@ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_fr else: attention_mask = None + uncond_input_ids = uncond_input.input_ids + uncond_input_ids, attention_mask = self.insert_inversions(uncond_input_ids, attention_mask) uncond_embeddings = self.text_encoder( - uncond_input.input_ids.to(device), + uncond_input_ids.to(device), attention_mask=attention_mask, ) uncond_embeddings = uncond_embeddings[0] @@ -242,8 +303,9 @@ def decode_latents(self, latents): latents = rearrange(latents, "b c f h w -> (b f) c h w") # video = self.vae.decode(latents).sample video = [] + device = self._execution_device for frame_idx in tqdm(range(latents.shape[0])): - video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) + video.append(self.vae.decode(latents[frame_idx:frame_idx+1].to(device)).sample) video = torch.cat(video) video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) video = (video / 2 + 0.5).clamp(0, 1) @@ -317,6 +379,9 @@ def __call__( self, prompt: Union[str, List[str]], video_length: Optional[int], + temporal_context: Optional[int] = None, + strides: int = 3, + overlap: int = 4, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -330,6 +395,8 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, + seq_policy=overlap_policy.uniform, + fp16=False, **kwargs, ): # Default height and width to unet @@ -348,6 +415,7 @@ def __call__( batch_size = len(prompt) device = self._execution_device + cpu = torch.device('cpu') # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. @@ -356,7 +424,7 @@ def __call__( # Encode input prompt prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size if negative_prompt is not None: - negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size + negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size text_embeddings = self._encode_prompt( prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt ) @@ -373,8 +441,8 @@ def __call__( video_length, height, width, - text_embeddings.dtype, - device, + torch.float32, + cpu, # using cpu to store latents allows generated frame amount not to be limited by vram but by ram generator, latents, ) @@ -382,28 +450,33 @@ def __call__( # Prepare extra step kwargs. extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - + total = sum( + len(list(seq_policy(i, num_inference_steps, latents.shape[2], temporal_context, strides, overlap))) + for i in range(len(timesteps)) + ) # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=total) as progress_bar: for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) - # noise_pred = [] - # import pdb - # pdb.set_trace() - # for batch_idx in range(latent_model_input.shape[0]): - # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype) - # noise_pred.append(noise_pred_single) - # noise_pred = torch.cat(noise_pred) + noise_pred = torch.zeros((latents.shape[0] * (2 if do_classifier_free_guidance else 1), + *latents.shape[1:]), device=latents.device, dtype=latents_dtype) + counter = torch.zeros((1, 1, latents.shape[2], 1, 1), device=latents.device, dtype=latents_dtype) + for seq in seq_policy(i, num_inference_steps, latents.shape[2], temporal_context, strides, overlap): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents[:, :, seq].to(device)\ + .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + with torch.autocast('cuda', enabled=fp16, dtype=torch.float16): + pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings) + noise_pred[:, :, seq] += pred.sample.to(dtype=latents_dtype, device=cpu) + counter[:, :, seq] += 1 + progress_bar.update() # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 @@ -411,7 +484,6 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/animatediff/utils/overlap_policy.py b/animatediff/utils/overlap_policy.py new file mode 100644 index 00000000..98dee361 --- /dev/null +++ b/animatediff/utils/overlap_policy.py @@ -0,0 +1,20 @@ +import numpy as np + + +def ordered_halving(i): + return int('{:064b}'.format(i)[::-1], 2) / (1 << 64) + + +def uniform(step, steps, n, context_size, strides, overlap, closed_loop=True): + if n <= context_size: + yield list(range(n)) + return + strides = min(strides, int(np.ceil(np.log2(n / context_size))) + 1) + for stride in 1 << np.arange(strides): + pad = int(round(n * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * stride) + pad, + n + pad + (0 if closed_loop else -overlap), + (context_size * stride - overlap) + ): + yield [e % n for e in range(j, j + context_size * stride, stride)] diff --git a/animatediff/utils/path.py b/animatediff/utils/path.py new file mode 100644 index 00000000..bb79b57c --- /dev/null +++ b/animatediff/utils/path.py @@ -0,0 +1,16 @@ +import os + +project_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def get_absolute_path(*relative): + if relative[0].startswith('/'): + return os.path.join(*relative) # absolute path + return os.path.join(project_path, *relative) + + +if __name__ == '__main__': + print(get_absolute_path('test')) + print(get_absolute_path('/test')) + print(get_absolute_path('test', 'test')) + print(get_absolute_path('/test', 'test')) diff --git a/models/embeddings/Place Textual Inversion embeddings here.txt b/models/embeddings/Place Textual Inversion embeddings here.txt new file mode 100644 index 00000000..e69de29b diff --git a/scripts/animate.py b/scripts/animate.py index 8bb5dd74..f6b8c6cd 100644 --- a/scripts/animate.py +++ b/scripts/animate.py @@ -30,7 +30,12 @@ def main(args): *_, func_args = inspect.getargvalues(inspect.currentframe()) func_args = dict(func_args) - + + if args.context_length == 0: + args.context_length = args.L + if args.context_overlap == -1: + args.context_overlap = args.context_length // 2 + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") savedir = f"samples/{Path(args.config).stem}-{time_str}" os.makedirs(savedir) @@ -58,6 +63,7 @@ def main(args): pipeline = AnimationPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + scan_inversions=not args.disable_inversions, ).to("cuda") # 1. unet ckpt @@ -130,6 +136,10 @@ def main(args): width = args.W, height = args.H, video_length = args.L, + temporal_context = args.context_length, + strides = args.context_stride + 1, + overlap = args.context_overlap, + fp16 = not args.fp32, ).videos samples.append(sample) @@ -150,7 +160,18 @@ def main(args): parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",) parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml") parser.add_argument("--config", type=str, required=True) - + + parser.add_argument("--fp32", action="store_true") + parser.add_argument("--disable_inversions", action="store_true", + help="do not scan for downloaded textual inversions") + + parser.add_argument("--context_length", type=int, default=0, + help="temporal transformer context length (0 for same as -L)") + parser.add_argument("--context_stride", type=int, default=0, + help="max stride of motion is 2^context_stride") + parser.add_argument("--context_overlap", type=int, default=-1, + help="overlap between chunks of context (-1 for half of context length)") + parser.add_argument("--L", type=int, default=16 ) parser.add_argument("--W", type=int, default=512) parser.add_argument("--H", type=int, default=512)